kaveh commited on
Commit
54e160a
·
1 Parent(s): 78b9263
training/evaluate.py CHANGED
@@ -31,7 +31,8 @@ def main():
31
  args = parser.parse_args()
32
 
33
  from data.cell_dataset import load_folder_data
34
- from models.s2f_model import create_s2f_model, compute_settings_normalization
 
35
  from utils.metrics import (
36
  evaluate_metrics_on_dataset,
37
  print_metrics_report,
@@ -54,7 +55,8 @@ def main():
54
  )
55
 
56
  in_channels = 3 if use_settings else 1
57
- generator, _ = create_s2f_model(in_channels=in_channels)
 
58
  ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
59
  generator.load_state_dict(ckpt.get('generator_state_dict', ckpt), strict=True)
60
 
 
31
  args = parser.parse_args()
32
 
33
  from data.cell_dataset import load_folder_data
34
+ from models.s2f_model import create_s2f_model
35
+ from utils.substrate_settings import compute_settings_normalization
36
  from utils.metrics import (
37
  evaluate_metrics_on_dataset,
38
  print_metrics_report,
 
55
  )
56
 
57
  in_channels = 3 if use_settings else 1
58
+ model_type = 's2f' if use_settings else 's2f_spheroid'
59
+ generator, _ = create_s2f_model(in_channels=in_channels, model_type=model_type)
60
  ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
61
  generator.load_state_dict(ckpt.get('generator_state_dict', ckpt), strict=True)
62
 
training/s2f_trainer.py CHANGED
@@ -14,7 +14,8 @@ S2F_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14
  if S2F_ROOT not in sys.path:
15
  sys.path.insert(0, S2F_ROOT)
16
 
17
- from models.s2f_model import create_settings_channels, compute_settings_normalization
 
18
  from utils.metrics import calculate_psnr, calculate_ssim_tensor, calculate_pearson_correlation
19
  from scipy.stats import pearsonr
20
 
 
14
  if S2F_ROOT not in sys.path:
15
  sys.path.insert(0, S2F_ROOT)
16
 
17
+ from models.s2f_model import create_settings_channels
18
+ from utils.substrate_settings import compute_settings_normalization
19
  from utils.metrics import calculate_psnr, calculate_ssim_tensor, calculate_pearson_correlation
20
  from scipy.stats import pearsonr
21
 
training/train.py CHANGED
@@ -54,7 +54,8 @@ def main():
54
  )
55
 
56
  in_channels = 3 if use_settings else 1
57
- generator, discriminator = create_s2f_model(in_channels=in_channels)
 
58
 
59
  if args.resume:
60
  ckpt = __import__('torch').load(args.resume, map_location='cpu', weights_only=False)
 
54
  )
55
 
56
  in_channels = 3 if use_settings else 1
57
+ model_type = 's2f' if use_settings else 's2f_spheroid'
58
+ generator, discriminator = create_s2f_model(in_channels=in_channels, model_type=model_type)
59
 
60
  if args.resume:
61
  ckpt = __import__('torch').load(args.resume, map_location='cpu', weights_only=False)