updated
Browse files- training/evaluate.py +4 -2
- training/s2f_trainer.py +2 -1
- training/train.py +2 -1
training/evaluate.py
CHANGED
|
@@ -31,7 +31,8 @@ def main():
|
|
| 31 |
args = parser.parse_args()
|
| 32 |
|
| 33 |
from data.cell_dataset import load_folder_data
|
| 34 |
-
from models.s2f_model import create_s2f_model
|
|
|
|
| 35 |
from utils.metrics import (
|
| 36 |
evaluate_metrics_on_dataset,
|
| 37 |
print_metrics_report,
|
|
@@ -54,7 +55,8 @@ def main():
|
|
| 54 |
)
|
| 55 |
|
| 56 |
in_channels = 3 if use_settings else 1
|
| 57 |
-
|
|
|
|
| 58 |
ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
|
| 59 |
generator.load_state_dict(ckpt.get('generator_state_dict', ckpt), strict=True)
|
| 60 |
|
|
|
|
| 31 |
args = parser.parse_args()
|
| 32 |
|
| 33 |
from data.cell_dataset import load_folder_data
|
| 34 |
+
from models.s2f_model import create_s2f_model
|
| 35 |
+
from utils.substrate_settings import compute_settings_normalization
|
| 36 |
from utils.metrics import (
|
| 37 |
evaluate_metrics_on_dataset,
|
| 38 |
print_metrics_report,
|
|
|
|
| 55 |
)
|
| 56 |
|
| 57 |
in_channels = 3 if use_settings else 1
|
| 58 |
+
model_type = 's2f' if use_settings else 's2f_spheroid'
|
| 59 |
+
generator, _ = create_s2f_model(in_channels=in_channels, model_type=model_type)
|
| 60 |
ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
|
| 61 |
generator.load_state_dict(ckpt.get('generator_state_dict', ckpt), strict=True)
|
| 62 |
|
training/s2f_trainer.py
CHANGED
|
@@ -14,7 +14,8 @@ S2F_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
| 14 |
if S2F_ROOT not in sys.path:
|
| 15 |
sys.path.insert(0, S2F_ROOT)
|
| 16 |
|
| 17 |
-
from models.s2f_model import create_settings_channels
|
|
|
|
| 18 |
from utils.metrics import calculate_psnr, calculate_ssim_tensor, calculate_pearson_correlation
|
| 19 |
from scipy.stats import pearsonr
|
| 20 |
|
|
|
|
| 14 |
if S2F_ROOT not in sys.path:
|
| 15 |
sys.path.insert(0, S2F_ROOT)
|
| 16 |
|
| 17 |
+
from models.s2f_model import create_settings_channels
|
| 18 |
+
from utils.substrate_settings import compute_settings_normalization
|
| 19 |
from utils.metrics import calculate_psnr, calculate_ssim_tensor, calculate_pearson_correlation
|
| 20 |
from scipy.stats import pearsonr
|
| 21 |
|
training/train.py
CHANGED
|
@@ -54,7 +54,8 @@ def main():
|
|
| 54 |
)
|
| 55 |
|
| 56 |
in_channels = 3 if use_settings else 1
|
| 57 |
-
|
|
|
|
| 58 |
|
| 59 |
if args.resume:
|
| 60 |
ckpt = __import__('torch').load(args.resume, map_location='cpu', weights_only=False)
|
|
|
|
| 54 |
)
|
| 55 |
|
| 56 |
in_channels = 3 if use_settings else 1
|
| 57 |
+
model_type = 's2f' if use_settings else 's2f_spheroid'
|
| 58 |
+
generator, discriminator = create_s2f_model(in_channels=in_channels, model_type=model_type)
|
| 59 |
|
| 60 |
if args.resume:
|
| 61 |
ckpt = __import__('torch').load(args.resume, map_location='cpu', weights_only=False)
|