Emaad commited on
Commit
3d993f1
1 Parent(s): e9c0dec

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +4 -1
prediction.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  os.chdir('..')
 
3
  from dataloader import CellLoader
4
  from celle_main import instantiate_from_config
5
  from omegaconf import OmegaConf
@@ -46,7 +47,9 @@ def run_image_prediction(
46
  config["model"]["params"]["vqgan_model_path"] = None
47
 
48
  # Instantiate model from config and move to device
49
- model = instantiate_from_config(config).to(device)
 
 
50
 
51
  # Sample from model using provided sequence and nucleus image
52
  _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
 
1
  import os
2
  os.chdir('..')
3
+ base_dir = os.getcwd()
4
  from dataloader import CellLoader
5
  from celle_main import instantiate_from_config
6
  from omegaconf import OmegaConf
 
47
  config["model"]["params"]["vqgan_model_path"] = None
48
 
49
  # Instantiate model from config and move to device
50
+ model = instantiate_from_config(config.model).to(device)
51
+
52
+ os.chdir(base_dir)
53
 
54
  # Sample from model using provided sequence and nucleus image
55
  _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(