VisionLanguageGroup commited on
Commit
102cd7d
·
1 Parent(s): 25ea9b7
models/seg_post_model/cellpose/__init__.py CHANGED
@@ -1 +1 @@
1
- from .version import version, version_str
 
1
+ # from .version import version, version_str
models/seg_post_model/cellpose/__main__.py DELETED
@@ -1,272 +0,0 @@
1
- """
2
- Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
- """
4
- import os, time
5
- import numpy as np
6
- from tqdm import tqdm
7
- from cellpose import utils, models, io, train
8
- from .version import version_str
9
- from cellpose.cli import get_arg_parser
10
-
11
- try:
12
- from cellpose.gui import gui3d, gui
13
- GUI_ENABLED = True
14
- except ImportError as err:
15
- GUI_ERROR = err
16
- GUI_ENABLED = False
17
- GUI_IMPORT = True
18
- except Exception as err:
19
- GUI_ENABLED = False
20
- GUI_ERROR = err
21
- GUI_IMPORT = False
22
- raise
23
-
24
- import logging
25
-
26
-
27
- def main():
28
- """ Run cellpose from command line
29
- """
30
-
31
- args = get_arg_parser().parse_args() # this has to be in a separate file for autodoc to work
32
-
33
- if args.version:
34
- print(version_str)
35
- return
36
-
37
- ######## if no image arguments are provided, run GUI or add model and exit ########
38
- if len(args.dir) == 0 and len(args.image_path) == 0:
39
- if args.add_model:
40
- io.add_model(args.add_model)
41
- return
42
- else:
43
- if not GUI_ENABLED:
44
- print("GUI ERROR: %s" % GUI_ERROR)
45
- if GUI_IMPORT:
46
- print(
47
- "GUI FAILED: GUI dependencies may not be installed, to install, run"
48
- )
49
- print(" pip install 'cellpose[gui]'")
50
- else:
51
- if args.Zstack:
52
- gui3d.run()
53
- else:
54
- gui.run()
55
- return
56
-
57
- ############################## run cellpose on images ##############################
58
- if args.verbose:
59
- from .io import logger_setup
60
- logger, log_file = logger_setup()
61
- else:
62
- print(
63
- ">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose")
64
- print("No --verbose => no progress or info printed")
65
- logger = logging.getLogger(__name__)
66
-
67
-
68
- # find images
69
- if len(args.img_filter) > 0:
70
- image_filter = args.img_filter
71
- else:
72
- image_filter = None
73
-
74
- device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu,
75
- device=args.gpu_device)
76
-
77
- if args.pretrained_model is None or args.pretrained_model == "None" or args.pretrained_model == "False" or args.pretrained_model == "0":
78
- pretrained_model = "cpsam"
79
- logger.warning("training from scratch is disabled, using 'cpsam' model")
80
- else:
81
- pretrained_model = args.pretrained_model
82
-
83
- # Warn users about old arguments from CP3:
84
- if args.pretrained_model_ortho:
85
- logger.warning(
86
- "the '--pretrained_model_ortho' flag is deprecated in v4.0.1+ and no longer used")
87
- if args.train_size:
88
- logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used")
89
- if args.chan or args.chan2:
90
- logger.warning('--chan and --chan2 are deprecated, all channels are used by default')
91
- if args.all_channels:
92
- logger.warning("the '--all_channels' flag is deprecated in v4.0.1+ and no longer used")
93
- if args.restore_type:
94
- logger.warning("the '--restore_type' flag is deprecated in v4.0.1+ and no longer used")
95
- if args.transformer:
96
- logger.warning("the '--tranformer' flag is deprecated in v4.0.1+ and no longer used")
97
- if args.invert:
98
- logger.warning("the '--invert' flag is deprecated in v4.0.1+ and no longer used")
99
- if args.chan2_restore:
100
- logger.warning("the '--chan2_restore' flag is deprecated in v4.0.1+ and no longer used")
101
- if args.diam_mean:
102
- logger.warning("the '--diam_mean' flag is deprecated in v4.0.1+ and no longer used")
103
- if args.train_size:
104
- logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used")
105
-
106
- if args.norm_percentile is not None:
107
- value1, value2 = args.norm_percentile
108
- normalize = {'percentile': (float(value1), float(value2))}
109
- else:
110
- normalize = (not args.no_norm)
111
-
112
- if args.save_each:
113
- if not args.save_every:
114
- raise ValueError("ERROR: --save_each requires --save_every")
115
-
116
- if len(args.image_path) > 0 and args.train:
117
- raise ValueError("ERROR: cannot train model with single image input")
118
-
119
- ## Run evaluation on images
120
- if not args.train:
121
- _evaluate_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
122
-
123
- ## Train a model ##
124
- else:
125
- _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
126
-
127
-
128
- def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize):
129
- test_dir = None if len(args.test_dir) == 0 else args.test_dir
130
- images, labels, image_names, train_probs = None, None, None, None
131
- test_images, test_labels, image_names_test, test_probs = None, None, None, None
132
- compute_flows = False
133
- if len(args.file_list) > 0:
134
- if os.path.exists(args.file_list):
135
- dat = np.load(args.file_list, allow_pickle=True).item()
136
- image_names = dat["train_files"]
137
- image_names_test = dat.get("test_files", None)
138
- train_probs = dat.get("train_probs", None)
139
- test_probs = dat.get("test_probs", None)
140
- compute_flows = dat.get("compute_flows", False)
141
- load_files = False
142
- else:
143
- logger.critical(f"ERROR: {args.file_list} does not exist")
144
- else:
145
- output = io.load_train_test_data(args.dir, test_dir, image_filter,
146
- args.mask_filter,
147
- args.look_one_level_down)
148
- images, labels, image_names, test_images, test_labels, image_names_test = output
149
- load_files = True
150
-
151
- # initialize model
152
- model = models.CellposeModel(device=device, pretrained_model=pretrained_model)
153
-
154
- # train segmentation model
155
- cpmodel_path = train.train_seg(
156
- model.net, images, labels, train_files=image_names,
157
- test_data=test_images, test_labels=test_labels,
158
- test_files=image_names_test, train_probs=train_probs,
159
- test_probs=test_probs, compute_flows=compute_flows,
160
- load_files=load_files, normalize=normalize,
161
- channel_axis=args.channel_axis,
162
- learning_rate=args.learning_rate, weight_decay=args.weight_decay,
163
- SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.train_batch_size,
164
- min_train_masks=args.min_train_masks,
165
- nimg_per_epoch=args.nimg_per_epoch,
166
- nimg_test_per_epoch=args.nimg_test_per_epoch,
167
- save_path=os.path.realpath(args.dir),
168
- save_every=args.save_every,
169
- save_each=args.save_each,
170
- model_name=args.model_name_out)[0]
171
- model.pretrained_model = cpmodel_path
172
- logger.info(">>>> model trained and saved to %s" % cpmodel_path)
173
- return model
174
-
175
-
176
- def _evaluate_cellposemodel_cli(args, logger, imf, device, pretrained_model, normalize):
177
- # Check with user if they REALLY mean to run without saving anything
178
- if not args.train:
179
- saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt
180
-
181
- tic = time.time()
182
- if len(args.dir) > 0:
183
- image_names = io.get_image_files(
184
- args.dir, args.mask_filter, imf=imf,
185
- look_one_level_down=args.look_one_level_down)
186
- else:
187
- if os.path.exists(args.image_path):
188
- image_names = [args.image_path]
189
- else:
190
- raise ValueError(f"ERROR: no file found at {args.image_path}")
191
- nimg = len(image_names)
192
-
193
- if args.savedir:
194
- if not os.path.exists(args.savedir):
195
- raise FileExistsError(f"--savedir {args.savedir} does not exist")
196
-
197
- logger.info(
198
- ">>>> running cellpose on %d images using all channels" % nimg)
199
-
200
- # handle built-in model exceptions
201
- model = models.CellposeModel(device=device, pretrained_model=pretrained_model,)
202
-
203
- tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)
204
-
205
- channel_axis = args.channel_axis
206
- z_axis = args.z_axis
207
-
208
- for image_name in tqdm(image_names, file=tqdm_out):
209
- if args.do_3D or args.stitch_threshold > 0.:
210
- logger.info('loading image as 3D zstack')
211
- image = io.imread_3D(image_name)
212
- if channel_axis is None:
213
- channel_axis = 3
214
- if z_axis is None:
215
- z_axis = 0
216
-
217
- else:
218
- image = io.imread_2D(image_name)
219
- out = model.eval(
220
- image,
221
- diameter=args.diameter,
222
- do_3D=args.do_3D,
223
- augment=args.augment,
224
- flow_threshold=args.flow_threshold,
225
- cellprob_threshold=args.cellprob_threshold,
226
- stitch_threshold=args.stitch_threshold,
227
- min_size=args.min_size,
228
- batch_size=args.batch_size,
229
- bsize=args.bsize,
230
- resample=not args.no_resample,
231
- normalize=normalize,
232
- channel_axis=channel_axis,
233
- z_axis=z_axis,
234
- anisotropy=args.anisotropy,
235
- niter=args.niter,
236
- flow3D_smooth=args.flow3D_smooth)
237
- masks, flows = out[:2]
238
-
239
- if args.exclude_on_edges:
240
- masks = utils.remove_edge_masks(masks)
241
- if not args.no_npy:
242
- io.masks_flows_to_seg(image, masks, flows, image_name,
243
- imgs_restore=None,
244
- restore_type=None,
245
- ratio=1.)
246
- if saving_something:
247
- suffix = "_cp_masks"
248
- if args.output_name is not None:
249
- # (1) If `savedir` is not defined, then must have a non-zero `suffix`
250
- if args.savedir is None and len(args.output_name) > 0:
251
- suffix = args.output_name
252
- elif args.savedir is not None and not os.path.samefile(args.savedir, args.dir):
253
- # (2) If `savedir` is defined, and different from `dir` then
254
- # takes the value passed as a param. (which can be empty string)
255
- suffix = args.output_name
256
-
257
- io.save_masks(image, masks, flows, image_name,
258
- suffix=suffix, png=args.save_png,
259
- tif=args.save_tif, save_flows=args.save_flows,
260
- save_outlines=args.save_outlines,
261
- dir_above=args.dir_above, savedir=args.savedir,
262
- save_txt=args.save_txt, in_folders=args.in_folders,
263
- save_mpl=args.save_mpl)
264
- if args.save_rois:
265
- io.save_rois(masks, image_name)
266
- logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))
267
-
268
- return model
269
-
270
-
271
- if __name__ == "__main__":
272
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/cli.py DELETED
@@ -1,240 +0,0 @@
1
- """
2
- Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu and Michael Rariden.
3
- """
4
-
5
- import argparse
6
-
7
-
8
- def get_arg_parser():
9
- """ Parses command line arguments for cellpose main function
10
-
11
- Note: this function has to be in a separate file to allow autodoc to work for CLI.
12
- The autodoc_mock_imports in conf.py does not work for sphinx-argparse sometimes,
13
- see https://github.com/ashb/sphinx-argparse/issues/9#issue-1097057823
14
- """
15
-
16
- parser = argparse.ArgumentParser(description="Cellpose Command Line Parameters")
17
-
18
- # misc settings
19
- parser.add_argument("--version", action="store_true",
20
- help="show cellpose version info")
21
- parser.add_argument(
22
- "--verbose", action="store_true",
23
- help="show information about running and settings and save to log")
24
- parser.add_argument("--Zstack", action="store_true", help="run GUI in 3D mode")
25
-
26
- # settings for CPU vs GPU
27
- hardware_args = parser.add_argument_group("Hardware Arguments")
28
- hardware_args.add_argument("--use_gpu", action="store_true",
29
- help="use gpu if torch with cuda installed")
30
- hardware_args.add_argument(
31
- "--gpu_device", required=False, default="0", type=str,
32
- help="which gpu device to use, use an integer for torch, or mps for M1")
33
-
34
- # settings for locating and formatting images
35
- input_img_args = parser.add_argument_group("Input Image Arguments")
36
- input_img_args.add_argument("--dir", default=[], type=str,
37
- help="folder containing data to run or train on.")
38
- input_img_args.add_argument(
39
- "--image_path", default=[], type=str, help=
40
- "if given and --dir not given, run on single image instead of folder (cannot train with this option)"
41
- )
42
- input_img_args.add_argument(
43
- "--look_one_level_down", action="store_true",
44
- help="run processing on all subdirectories of current folder")
45
- input_img_args.add_argument("--img_filter", default=[], type=str,
46
- help="end string for images to run on")
47
- input_img_args.add_argument(
48
- "--channel_axis", default=None, type=int,
49
- help="axis of image which corresponds to image channels")
50
- input_img_args.add_argument("--z_axis", default=None, type=int,
51
- help="axis of image which corresponds to Z dimension")
52
-
53
- # TODO: remove deprecated in future version
54
- input_img_args.add_argument(
55
- "--chan", default=0, type=int, help=
56
- "Deprecated in v4.0.1+, not used. ")
57
- input_img_args.add_argument(
58
- "--chan2", default=0, type=int, help=
59
- 'Deprecated in v4.0.1+, not used. ')
60
- input_img_args.add_argument("--invert", action="store_true", help=
61
- 'Deprecated in v4.0.1+, not used. ')
62
- input_img_args.add_argument(
63
- "--all_channels", action="store_true", help=
64
- 'Deprecated in v4.0.1+, not used. ')
65
-
66
- # model settings
67
- model_args = parser.add_argument_group("Model Arguments")
68
- model_args.add_argument("--pretrained_model", required=False, default="cpsam",
69
- type=str,
70
- help="model to use for running or starting training")
71
- model_args.add_argument(
72
- "--add_model", required=False, default=None, type=str,
73
- help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")
74
- model_args.add_argument("--pretrained_model_ortho", required=False, default=None,
75
- type=str,
76
- help="Deprecated in v4.0.1+, not used. ")
77
-
78
- # TODO: remove deprecated in future version
79
- model_args.add_argument("--restore_type", required=False, default=None, type=str, help=
80
- 'Deprecated in v4.0.1+, not used. ')
81
- model_args.add_argument("--chan2_restore", action="store_true", help=
82
- 'Deprecated in v4.0.1+, not used. ')
83
- model_args.add_argument(
84
- "--transformer", action="store_true", help=
85
- "use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)")
86
-
87
- # algorithm settings
88
- algorithm_args = parser.add_argument_group("Algorithm Arguments")
89
- algorithm_args.add_argument("--no_norm", action="store_true",
90
- help="do not normalize images (normalize=False)")
91
- algorithm_args.add_argument(
92
- '--norm_percentile',
93
- nargs=2, # Require exactly two values
94
- metavar=('VALUE1', 'VALUE2'),
95
- help="Provide two float values to set norm_percentile (e.g., --norm_percentile 1 99)"
96
- )
97
- algorithm_args.add_argument(
98
- "--do_3D", action="store_true",
99
- help="process images as 3D stacks of images (nplanes x nchan x Ly x Lx")
100
- algorithm_args.add_argument(
101
- "--diameter", required=False, default=None, type=float, help=
102
- "use to resize cells to the training diameter (30 pixels)"
103
- )
104
- algorithm_args.add_argument(
105
- "--stitch_threshold", required=False, default=0.0, type=float,
106
- help="compute masks in 2D then stitch together masks with IoU>0.9 across planes"
107
- )
108
- algorithm_args.add_argument(
109
- "--min_size", required=False, default=15, type=int,
110
- help="minimum number of pixels per mask, can turn off with -1")
111
- algorithm_args.add_argument(
112
- "--flow3D_smooth", required=False, default=0, type=float,
113
- help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing")
114
- algorithm_args.add_argument(
115
- "--flow_threshold", default=0.4, type=float, help=
116
- "flow error threshold, 0 turns off this optional QC step. Default: %(default)s")
117
- algorithm_args.add_argument(
118
- "--cellprob_threshold", default=0, type=float,
119
- help="cellprob threshold, default is 0, decrease to find more and larger masks")
120
- algorithm_args.add_argument(
121
- "--niter", default=0, type=int, help=
122
- "niter, number of iterations for dynamics for mask creation, default of 0 means it is proportional to diameter, set to a larger number like 2000 for very long ROIs"
123
- )
124
- algorithm_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
125
- help="anisotropy of volume in 3D")
126
- algorithm_args.add_argument("--exclude_on_edges", action="store_true",
127
- help="discard masks which touch edges of image")
128
- algorithm_args.add_argument(
129
- "--augment", action="store_true",
130
- help="tiles image with overlapping tiles and flips overlapped regions to augment"
131
- )
132
- algorithm_args.add_argument("--batch_size", default=8, type=int,
133
- help="inference batch size. Default: %(default)s")
134
-
135
- # TODO: remove deprecated in future version
136
- algorithm_args.add_argument(
137
- "--no_resample", action="store_true",
138
- help="disables flows/cellprob resampling to original image size before computing masks. Using this flag will make more masks more jagged with larger diameter settings.")
139
- algorithm_args.add_argument(
140
- "--no_interp", action="store_true",
141
- help="do not interpolate when running dynamics (was default)")
142
-
143
- # output settings
144
- output_args = parser.add_argument_group("Output Arguments")
145
- output_args.add_argument(
146
- "--save_png", action="store_true",
147
- help="save masks as png")
148
- output_args.add_argument(
149
- "--save_tif", action="store_true",
150
- help="save masks as tif")
151
- output_args.add_argument(
152
- "--output_name", default=None, type=str,
153
- help="suffix for saved masks, default is _cp_masks, can be empty if `savedir` used and different of `dir`")
154
- output_args.add_argument("--no_npy", action="store_true",
155
- help="suppress saving of npy")
156
- output_args.add_argument(
157
- "--savedir", default=None, type=str, help=
158
- "folder to which segmentation results will be saved (defaults to input image directory)"
159
- )
160
- output_args.add_argument(
161
- "--dir_above", action="store_true", help=
162
- "save output folders adjacent to image folder instead of inside it (off by default)"
163
- )
164
- output_args.add_argument("--in_folders", action="store_true",
165
- help="flag to save output in folders (off by default)")
166
- output_args.add_argument(
167
- "--save_flows", action="store_true", help=
168
- "whether or not to save RGB images of flows when masks are saved (disabled by default)"
169
- )
170
- output_args.add_argument(
171
- "--save_outlines", action="store_true", help=
172
- "whether or not to save RGB outline images when masks are saved (disabled by default)"
173
- )
174
- output_args.add_argument(
175
- "--save_rois", action="store_true",
176
- help="whether or not to save ImageJ compatible ROI archive (disabled by default)"
177
- )
178
- output_args.add_argument(
179
- "--save_txt", action="store_true",
180
- help="flag to enable txt outlines for ImageJ (disabled by default)")
181
- output_args.add_argument(
182
- "--save_mpl", action="store_true",
183
- help="save a figure of image/mask/flows using matplotlib (disabled by default). "
184
- "This is slow, especially with large images.")
185
-
186
- # training settings
187
- training_args = parser.add_argument_group("Training Arguments")
188
- training_args.add_argument("--train", action="store_true",
189
- help="train network using images in dir")
190
- training_args.add_argument("--test_dir", default=[], type=str,
191
- help="folder containing test data (optional)")
192
- training_args.add_argument(
193
- "--file_list", default=[], type=str, help=
194
- "path to list of files for training and testing and probabilities for each image (optional)"
195
- )
196
- training_args.add_argument(
197
- "--mask_filter", default="_masks", type=str, help=
198
- "end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. Default: %(default)s"
199
- )
200
- training_args.add_argument("--learning_rate", default=1e-5, type=float,
201
- help="learning rate. Default: %(default)s")
202
- training_args.add_argument("--weight_decay", default=0.1, type=float,
203
- help="weight decay. Default: %(default)s")
204
- training_args.add_argument("--n_epochs", default=100, type=int,
205
- help="number of epochs. Default: %(default)s")
206
- training_args.add_argument("--train_batch_size", default=1, type=int,
207
- help="training batch size. Default: %(default)s")
208
- training_args.add_argument("--bsize", default=256, type=int,
209
- help="block size for tiles. Default: %(default)s")
210
- training_args.add_argument(
211
- "--nimg_per_epoch", default=None, type=int,
212
- help="number of train images per epoch. Default is to use all train images.")
213
- training_args.add_argument(
214
- "--nimg_test_per_epoch", default=None, type=int,
215
- help="number of test images per epoch. Default is to use all test images.")
216
- training_args.add_argument(
217
- "--min_train_masks", default=5, type=int, help=
218
- "minimum number of masks a training image must have to be used. Default: %(default)s"
219
- )
220
- training_args.add_argument("--SGD", default=0, type=int,
221
- help="Deprecated in v4.0.1+, not used - AdamW used instead. ")
222
- training_args.add_argument(
223
- "--save_every", default=100, type=int,
224
- help="number of epochs to skip between saves. Default: %(default)s")
225
- training_args.add_argument(
226
- "--save_each", action="store_true",
227
- help="wether or not to save each epoch. Must also use --save_every. (default: False)")
228
- training_args.add_argument(
229
- "--model_name_out", default=None, type=str,
230
- help="Name of model to save as, defaults to name describing model architecture. "
231
- "Model is saved in the folder specified by --dir in models subfolder.")
232
-
233
- # TODO: remove deprecated in future version
234
- training_args.add_argument(
235
- "--diam_mean", default=30., type=float, help=
236
- 'Deprecated in v4.0.1+, not used. ')
237
- training_args.add_argument("--train_size", action="store_true", help=
238
- 'Deprecated in v4.0.1+, not used. ')
239
-
240
- return parser
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/denoise.py DELETED
@@ -1,1474 +0,0 @@
1
- """
2
- Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
- """
4
- import os, time, datetime
5
- import numpy as np
6
- from scipy.stats import mode
7
- import cv2
8
- import torch
9
- from torch import nn
10
- from torch.nn.functional import conv2d, interpolate
11
- from tqdm import trange
12
- from pathlib import Path
13
-
14
- import logging
15
-
16
- denoise_logger = logging.getLogger(__name__)
17
-
18
- from cellpose import transforms, utils, io
19
- from cellpose.core import run_net
20
- from cellpose.models import CellposeModel, model_path, normalize_default, assign_device
21
-
22
- MODEL_NAMES = []
23
- for ctype in ["cyto3", "cyto2", "nuclei"]:
24
- for ntype in ["denoise", "deblur", "upsample", "oneclick"]:
25
- MODEL_NAMES.append(f"{ntype}_{ctype}")
26
- if ctype != "cyto3":
27
- for ltype in ["per", "seg", "rec"]:
28
- MODEL_NAMES.append(f"{ntype}_{ltype}_{ctype}")
29
- if ctype != "cyto3":
30
- MODEL_NAMES.append(f"aniso_{ctype}")
31
-
32
- criterion = nn.MSELoss(reduction="mean")
33
- criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
34
-
35
-
36
- def deterministic(seed=0):
37
- """ set random seeds to create test data """
38
- import random
39
- torch.manual_seed(seed)
40
- torch.cuda.manual_seed(seed)
41
- torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
42
- np.random.seed(seed) # Numpy module.
43
- random.seed(seed) # Python random module.
44
- torch.manual_seed(seed)
45
- torch.backends.cudnn.benchmark = False
46
- torch.backends.cudnn.deterministic = True
47
-
48
-
49
- def loss_fn_rec(lbl, y):
50
- """ loss function between true labels lbl and prediction y """
51
- loss = 80. * criterion(y, lbl)
52
- return loss
53
-
54
-
55
- def loss_fn_seg(lbl, y):
56
- """ loss function between true labels lbl and prediction y """
57
- veci = 5. * lbl[:, 1:]
58
- lbl = (lbl[:, 0] > .5).float()
59
- loss = criterion(y[:, :2], veci)
60
- loss /= 2.
61
- loss2 = criterion2(y[:, 2], lbl)
62
- loss = loss + loss2
63
- return loss
64
-
65
-
66
- def get_sigma(Tdown):
67
- """ Calculates the correlation matrices across channels for the perceptual loss.
68
-
69
- Args:
70
- Tdown (list): List of tensors output by each downsampling block of network.
71
-
72
- Returns:
73
- list: List of correlations for each input tensor.
74
- """
75
- Tnorm = [x - x.mean((-2, -1), keepdim=True) for x in Tdown]
76
- Tnorm = [x / x.std((-2, -1), keepdim=True) for x in Tnorm]
77
- Sigma = [
78
- torch.einsum("bnxy, bmxy -> bnm", x, x) / (x.shape[-2] * x.shape[-1])
79
- for x in Tnorm
80
- ]
81
- return Sigma
82
-
83
-
84
- def imstats(X, net1):
85
- """
86
- Calculates the image correlation matrices for the perceptual loss.
87
-
88
- Args:
89
- X (torch.Tensor): Input image tensor.
90
- net1: Cellpose net.
91
-
92
- Returns:
93
- list: A list of tensors of correlation matrices.
94
- """
95
- _, _, Tdown = net1(X)
96
- Sigma = get_sigma(Tdown)
97
- Sigma = [x.detach() for x in Sigma]
98
- return Sigma
99
-
100
-
101
- def loss_fn_per(img, net1, yl):
102
- """
103
- Calculates the perceptual loss function for image restoration.
104
-
105
- Args:
106
- img (torch.Tensor): Input image tensor (noisy/blurry/downsampled).
107
- net1 (torch.nn.Module): Perceptual loss net (Cellpose segmentation net).
108
- yl (torch.Tensor): Clean image tensor.
109
-
110
- Returns:
111
- torch.Tensor: Mean perceptual loss.
112
- """
113
- Sigma = imstats(img, net1)
114
- sd = [x.std((1, 2)) + 1e-6 for x in Sigma]
115
- Sigma_test = get_sigma(yl)
116
- losses = torch.zeros(len(Sigma[0]), device=img.device)
117
- for k in range(len(Sigma)):
118
- losses = losses + (((Sigma_test[k] - Sigma[k])**2).mean((1, 2)) / sd[k]**2)
119
- return losses.mean()
120
-
121
-
122
- def test_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
123
- """
124
- Calculates the test loss for image restoration tasks.
125
-
126
- Args:
127
- net0 (torch.nn.Module): The image restoration network.
128
- X (torch.Tensor): The input image tensor.
129
- net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
130
- img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
131
- lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
132
- lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.].
133
-
134
- Returns:
135
- tuple: A tuple containing the total loss and the perceptual loss.
136
- """
137
- net0.eval()
138
- if net1 is not None:
139
- net1.eval()
140
- loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)
141
-
142
- with torch.no_grad():
143
- img_dn = net0(X)[0]
144
- if lam[2] > 0.:
145
- loss += lam[2] * loss_fn_rec(img, img_dn)
146
- if lam[1] > 0. or lam[0] > 0.:
147
- y, _, ydown = net1(img_dn)
148
- if lam[1] > 0.:
149
- loss += lam[1] * loss_fn_seg(lbl, y)
150
- if lam[0] > 0.:
151
- loss_per = loss_fn_per(img, net1, ydown)
152
- loss += lam[0] * loss_per
153
- return loss, loss_per
154
-
155
-
156
- def train_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
157
- """
158
- Calculates the train loss for image restoration tasks.
159
-
160
- Args:
161
- net0 (torch.nn.Module): The image restoration network.
162
- X (torch.Tensor): The input image tensor.
163
- net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
164
- img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
165
- lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
166
- lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.].
167
-
168
- Returns:
169
- tuple: A tuple containing the total loss and the perceptual loss.
170
- """
171
- net0.train()
172
- if net1 is not None:
173
- net1.eval()
174
- loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)
175
-
176
- img_dn = net0(X)[0]
177
- if lam[2] > 0.:
178
- loss += lam[2] * loss_fn_rec(img, img_dn)
179
- if lam[1] > 0. or lam[0] > 0.:
180
- y, _, ydown = net1(img_dn)
181
- if lam[1] > 0.:
182
- loss += lam[1] * loss_fn_seg(lbl, y)
183
- if lam[0] > 0.:
184
- loss_per = loss_fn_per(img, net1, ydown)
185
- loss += lam[0] * loss_per
186
- return loss, loss_per
187
-
188
-
189
- def img_norm(imgi):
190
- """
191
- Normalizes the input image by subtracting the 1st percentile and dividing by the difference between the 99th and 1st percentiles.
192
-
193
- Args:
194
- imgi (torch.Tensor): Input image tensor.
195
-
196
- Returns:
197
- torch.Tensor: Normalized image tensor.
198
- """
199
- shape = imgi.shape
200
- imgi = imgi.reshape(imgi.shape[0], imgi.shape[1], -1)
201
- perc = torch.quantile(imgi, torch.tensor([0.01, 0.99], device=imgi.device), dim=-1,
202
- keepdim=True)
203
- for k in range(imgi.shape[1]):
204
- hask = (perc[1, :, k, 0] - perc[0, :, k, 0]) > 1e-3
205
- imgi[hask, k] -= perc[0, hask, k]
206
- imgi[hask, k] /= (perc[1, hask, k] - perc[0, hask, k])
207
- imgi = imgi.reshape(shape)
208
- return imgi
209
-
210
-
211
- def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7,
212
- ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None,
213
- ds=None, uniform_blur=False, partial_blur=False):
214
- """Adds noise to the input image.
215
-
216
- Args:
217
- lbl (torch.Tensor): The input image tensor of shape (nimg, nchan, Ly, Lx).
218
- alpha (float, optional): The shape parameter of the gamma distribution used for generating poisson noise. Defaults to 4.
219
- beta (float, optional): The rate parameter of the gamma distribution used for generating poisson noise. Defaults to 0.7.
220
- poisson (float, optional): The probability of adding poisson noise to the image. Defaults to 0.7.
221
- blur (float, optional): The probability of adding gaussian blur to the image. Defaults to 0.7.
222
- gblur (float, optional): The scale factor for the gaussian blur. Defaults to 1.0.
223
- downsample (float, optional): The probability of downsampling the image. Defaults to 0.7.
224
- ds_max (int, optional): The maximum downsampling factor. Defaults to 7.
225
- diams (torch.Tensor, optional): The diameter of the objects in the image. Defaults to None.
226
- pscale (torch.Tensor, optional): The scale factor for the poisson noise, instead of sampling. Defaults to None.
227
- iso (bool, optional): Whether to use isotropic gaussian blur. Defaults to True.
228
- sigma0 (torch.Tensor, optional): The standard deviation of the gaussian filter for the Y axis, instead of sampling. Defaults to None.
229
- sigma1 (torch.Tensor, optional): The standard deviation of the gaussian filter for the X axis, instead of sampling. Defaults to None.
230
- ds (torch.Tensor, optional): The downsampling factor for each image, instead of sampling. Defaults to None.
231
-
232
- Returns:
233
- torch.Tensor: The noisy image tensor of the same shape as the input image.
234
- """
235
- device = lbl.device
236
- imgi = torch.zeros_like(lbl)
237
- Ly, Lx = lbl.shape[-2:]
238
-
239
- diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device)
240
- #ds0 = 1 if ds is None else ds.item()
241
- ds = ds * torch.ones(
242
- (len(lbl),), device=device, dtype=torch.long) if ds is not None else ds
243
-
244
- # downsample
245
- ii = []
246
- idownsample = np.random.rand(len(lbl)) < downsample
247
- if (ds is None and idownsample.sum() > 0.) or not iso:
248
- ds = torch.ones(len(lbl), dtype=torch.long, device=device)
249
- ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),),
250
- device=device)
251
- ii = torch.nonzero(ds > 1).flatten()
252
- elif ds is not None and (ds > 1).sum():
253
- ii = torch.nonzero(ds > 1).flatten()
254
-
255
- # add gaussian blur
256
- iblur = torch.rand(len(lbl), device=device) < blur
257
- iblur[ii] = True
258
- if iblur.sum() > 0:
259
- if sigma0 is None:
260
- if uniform_blur and iso:
261
- xr = torch.rand(len(lbl), device=device)
262
- if len(ii) > 0:
263
- xr[ii] = ds[ii].float() / 2. / gblur
264
- sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur])
265
- sigma1 = sigma0.clone()
266
- elif not iso:
267
- xr = torch.rand(len(lbl), device=device)
268
- if len(ii) > 0:
269
- xr[ii] = (ds[ii].float()) / gblur
270
- xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35
271
- xr[ii] = torch.clip(xr[ii], 0.05, 1.5)
272
- sigma0 = diams[iblur] / 30. * gblur * xr[iblur]
273
- sigma1 = sigma0.clone() / 10.
274
- else:
275
- xrand = np.random.exponential(1, size=iblur.sum())
276
- xrand = np.clip(xrand * 0.5, 0.1, 1.0)
277
- xrand *= gblur
278
- sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to(
279
- device)
280
- sigma1 = sigma0.clone()
281
- else:
282
- sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device)
283
- sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device)
284
-
285
- # create gaussian filter
286
- xr = max(8, sigma0.max().long() * 2)
287
- gfilt0 = torch.exp(-torch.arange(-xr + 1, xr, device=device)**2 /
288
- (2 * sigma0.unsqueeze(-1)**2))
289
- gfilt0 /= gfilt0.sum(axis=-1, keepdims=True)
290
- gfilt1 = torch.zeros_like(gfilt0)
291
- gfilt1[sigma1 == sigma0] = gfilt0[sigma1 == sigma0]
292
- gfilt1[sigma1 != sigma0] = torch.exp(
293
- -torch.arange(-xr + 1, xr, device=device)**2 /
294
- (2 * sigma1[sigma1 != sigma0].unsqueeze(-1)**2))
295
- gfilt1[sigma1 == 0] = 0.
296
- gfilt1[sigma1 == 0, xr] = 1.
297
- gfilt1 /= gfilt1.sum(axis=-1, keepdims=True)
298
- gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1)
299
- gfilt /= gfilt.sum(axis=(1, 2), keepdims=True)
300
-
301
- lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1),
302
- padding=gfilt.shape[-1] // 2,
303
- groups=gfilt.shape[0]).transpose(1, 0)
304
- if partial_blur:
305
- #yc, xc = np.random.randint(100, Ly-100), np.random.randint(100, Lx-100)
306
- imgi[iblur] = lbl[iblur].clone()
307
- Lxc = int(Lx * 0.85)
308
- ym, xm = torch.meshgrid(torch.zeros(Ly, dtype=torch.float32),
309
- torch.arange(0, Lxc, dtype=torch.float32),
310
- indexing="ij")
311
- mask = torch.exp(-(ym**2 + xm**2) / 2*(0.001**2))
312
- mask -= mask.min()
313
- mask /= mask.max()
314
- lbl_blur_crop = lbl_blur[:, :, :, :Lxc]
315
- imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask +
316
- (1-mask) * imgi[iblur, :, :, :Lxc])
317
- else:
318
- imgi[iblur] = lbl_blur
319
-
320
- imgi[~iblur] = lbl[~iblur]
321
-
322
- # apply downsample
323
- for k in ii:
324
- i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]]
325
- imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear")
326
-
327
- # add poisson noise
328
- ipoisson = np.random.rand(len(lbl)) < poisson
329
- if ipoisson.sum() > 0:
330
- if pscale is None:
331
- pscale = torch.zeros(len(lbl))
332
- m = torch.distributions.gamma.Gamma(alpha, beta)
333
- pscale = torch.clamp(m.rsample(sample_shape=(ipoisson.sum(),)), 1.)
334
- #pscale = torch.clamp(20 * (torch.rand(size=(len(lbl),), device=lbl.device)), 1.5)
335
- pscale = pscale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(device)
336
- else:
337
- pscale = pscale * torch.ones((ipoisson.sum(), 1, 1, 1), device=device)
338
- imgi[ipoisson] = torch.poisson(pscale * imgi[ipoisson])
339
- imgi[~ipoisson] = imgi[~ipoisson]
340
-
341
- # renormalize
342
- imgi = img_norm(imgi)
343
-
344
- return imgi
345
-
346
-
347
- def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7,
348
- downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30,
349
- ds_max=7, uniform_blur=False, iso=True, rotate=True,
350
- device=torch.device("cuda"), xy=(224, 224),
351
- nchan_noise=1, keep_raw=True):
352
- """
353
- Applies random rotation, resizing, and noise to the input data.
354
-
355
- Args:
356
- data (numpy.ndarray): The input data.
357
- labels (numpy.ndarray, optional): The flow and cellprob labels associated with the data. Defaults to None.
358
- diams (float, optional): The diameter of the objects. Defaults to None.
359
- poisson (float, optional): The Poisson noise probability. Defaults to 0.7.
360
- blur (float, optional): The blur probability. Defaults to 0.7.
361
- downsample (float, optional): The downsample probability. Defaults to 0.0.
362
- beta (float, optional): The beta value for the poisson noise distribution. Defaults to 0.7.
363
- gblur (float, optional): The Gaussian blur level. Defaults to 1.0.
364
- diam_mean (float, optional): The mean diameter. Defaults to 30.
365
- ds_max (int, optional): The maximum downsample value. Defaults to 7.
366
- iso (bool, optional): Whether to apply isotropic augmentation. Defaults to True.
367
- rotate (bool, optional): Whether to apply rotation augmentation. Defaults to True.
368
- device (torch.device, optional): The device to use. Defaults to torch.device("cuda").
369
- xy (tuple, optional): The size of the output image. Defaults to (224, 224).
370
- nchan_noise (int, optional): The number of channels to add noise to. Defaults to 1.
371
- keep_raw (bool, optional): Whether to keep the raw image. Defaults to True.
372
-
373
- Returns:
374
- torch.Tensor: The augmented image and augmented noisy/blurry/downsampled version of image.
375
- torch.Tensor: The augmented labels.
376
- float: The scale factor applied to the image.
377
- """
378
- if device == None:
379
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
380
-
381
- diams = 30 if diams is None else diams
382
- random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1))
383
- random_rsc = diams / random_diam #/ random_diam
384
- #rsc /= random_scale
385
- xy0 = (340, 340)
386
- nchan = data[0].shape[0]
387
- data_new = np.zeros((len(data), (1 + keep_raw) * nchan, xy0[0], xy0[1]), "float32")
388
- labels_new = np.zeros((len(data), 3, xy0[0], xy0[1]), "float32")
389
- for i in range(
390
- len(data)): #, (sc, img, lbl) in enumerate(zip(random_rsc, data, labels)):
391
- sc = random_rsc[i]
392
- img = data[i]
393
- lbl = labels[i] if labels is not None else None
394
- # create affine transform to resize
395
- Ly, Lx = img.shape[-2:]
396
- dxy = np.maximum(0, np.array([Lx / sc - xy0[1], Ly / sc - xy0[0]]))
397
- dxy = (np.random.rand(2,) - .5) * dxy
398
- cc = np.array([Lx / 2, Ly / 2])
399
- cc1 = cc - np.array([Lx - xy0[1], Ly - xy0[0]]) / 2 + dxy
400
- pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
401
- pts2 = np.float32(
402
- [cc1, cc1 + np.array([1, 0]) / sc, cc1 + np.array([0, 1]) / sc])
403
- M = cv2.getAffineTransform(pts1, pts2)
404
-
405
- # apply to image
406
- for c in range(nchan):
407
- img_rsz = cv2.warpAffine(img[c], M, xy0, flags=cv2.INTER_LINEAR)
408
- #img_noise = add_noise(torch.from_numpy(img_rsz).to(device).unsqueeze(0)).cpu().numpy().squeeze(0)
409
- data_new[i, c] = img_rsz
410
- if keep_raw:
411
- data_new[i, c + nchan] = img_rsz
412
-
413
- if lbl is not None:
414
- # apply to labels
415
- labels_new[i, 0] = cv2.warpAffine(lbl[0], M, xy0, flags=cv2.INTER_NEAREST)
416
- labels_new[i, 1] = cv2.warpAffine(lbl[1], M, xy0, flags=cv2.INTER_LINEAR)
417
- labels_new[i, 2] = cv2.warpAffine(lbl[2], M, xy0, flags=cv2.INTER_LINEAR)
418
-
419
- rsc = random_diam / diam_mean
420
-
421
- # add noise before augmentations
422
- img = torch.from_numpy(data_new).to(device)
423
- img = torch.clamp(img, 0.)
424
- # just add noise to cyto if nchan_noise=1
425
- img[:, :nchan_noise] = add_noise(
426
- img[:, :nchan_noise], poisson=poisson, blur=blur, ds_max=ds_max, iso=iso,
427
- downsample=downsample, beta=beta, gblur=gblur,
428
- diams=torch.from_numpy(random_diam).to(device).float())
429
- # img -= img.mean(dim=(-2,-1), keepdim=True)
430
- # img /= img.std(dim=(-2,-1), keepdim=True) + 1e-3
431
- img = img.cpu().numpy()
432
-
433
- # augmentations
434
- img, lbl, scale = transforms.random_rotate_and_resize(
435
- img,
436
- Y=labels_new,
437
- xy=xy,
438
- rotate=False if not iso else rotate,
439
- #(iso and downsample==0),
440
- rescale=rsc,
441
- scale_range=0.5)
442
- img = torch.from_numpy(img).to(device)
443
- lbl = torch.from_numpy(lbl).to(device)
444
-
445
- return img, lbl, scale
446
-
447
-
448
- def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
449
- """
450
- Creates a Cellpose network with a single input channel.
451
-
452
- Args:
453
- device (str): The device to run the network on.
454
- model_type (str, optional): The type of Cellpose model to use. Defaults to "cyto2".
455
- pretrained_model (str, optional): The path to a pretrained model file. Defaults to None.
456
-
457
- Returns:
458
- torch.nn.Module: The Cellpose network with a single input channel.
459
- """
460
- if pretrained_model is not None and not os.path.exists(pretrained_model):
461
- model_type = pretrained_model
462
- pretrained_model = None
463
- nbase = [32, 64, 128, 256]
464
- nchan = 1
465
- net1 = resnet_torch.CPnet([nchan, *nbase], nout=3, sz=3).to(device)
466
- filename = model_path(model_type,
467
- 0) if pretrained_model is None else pretrained_model
468
- weights = torch.load(filename, weights_only=True)
469
- zp = 0
470
- print(filename)
471
- for name in net1.state_dict():
472
- if ("res_down_0.conv.conv_0" not in name and
473
- #"output" not in name and
474
- "res_down_0.proj" not in name and name != "diam_mean" and
475
- name != "diam_labels"):
476
- net1.state_dict()[name].copy_(weights[name])
477
- elif "res_down_0" in name:
478
- if len(weights[name].shape) > 0:
479
- new_weight = torch.zeros_like(net1.state_dict()[name])
480
- if weights[name].shape[0] == 2:
481
- new_weight[:] = weights[name][0]
482
- elif len(weights[name].shape) > 1 and weights[name].shape[1] == 2:
483
- new_weight[:, zp] = weights[name][:, 0]
484
- else:
485
- new_weight = weights[name]
486
- else:
487
- new_weight = weights[name]
488
- net1.state_dict()[name].copy_(new_weight)
489
- return net1
490
-
491
-
492
- class CellposeDenoiseModel():
493
- """ model to run Cellpose and Image restoration """
494
-
495
- def __init__(self, gpu=False, pretrained_model=False, model_type=None,
496
- restore_type="denoise_cyto3", nchan=2,
497
- chan2_restore=False, device=None):
498
-
499
- self.dn = DenoiseModel(gpu=gpu, model_type=restore_type, chan2=chan2_restore,
500
- device=device)
501
- self.cp = CellposeModel(gpu=gpu, model_type=model_type, nchan=nchan,
502
- pretrained_model=pretrained_model, device=device)
503
-
504
- def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
505
- normalize=True, rescale=None, diameter=None, tile_overlap=0.1,
506
- augment=False, resample=True, invert=False, flow_threshold=0.4,
507
- cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0,
508
- min_size=15, niter=None, interp=True, bsize=224, flow3D_smooth=0):
509
- """
510
- Restore array or list of images using the image restoration model, and then segment.
511
-
512
- Args:
513
- x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
514
- batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
515
- (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
516
- channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
517
- First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
518
- Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
519
- For instance, to segment grayscale images, input [0,0]. To segment images with cells
520
- in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
521
- image with cells in green and nuclei in blue, input [[0,0], [2,3]].
522
- Defaults to None.
523
- channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
524
- if None, channels dimension is attempted to be automatically determined. Defaults to None.
525
- z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
526
- if None, z dimension is attempted to be automatically determined. Defaults to None.
527
- normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
528
- can also pass dictionary of parameters (all keys are optional, default values shown):
529
- - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
530
- - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
531
- - "normalize"=True ; run normalization (if False, all following parameters ignored)
532
- - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
533
- - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
534
- - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
535
- Defaults to True.
536
- rescale (float, optional): resize factor for each image, if None, set to 1.0;
537
- (only used if diameter is None). Defaults to None.
538
- diameter (float, optional): diameter for each image,
539
- if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
540
- tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
541
- augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False.
542
- resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
543
- invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
544
- flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
545
- cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
546
- do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False.
547
- anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
548
- stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0.
549
- min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15.
550
- flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0.
551
- niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None.
552
- interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True.
553
-
554
- Returns:
555
- A tuple containing (masks, flows, styles, imgs); masks: labelled image(s), where 0=no masks; 1,2,...=mask labels;
556
- flows: list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration;
557
- styles: style vector summarizing each image of size 256;
558
- imgs: Restored images.
559
- """
560
-
561
- if isinstance(normalize, dict):
562
- normalize_params = {**normalize_default, **normalize}
563
- elif not isinstance(normalize, bool):
564
- raise ValueError("normalize parameter must be a bool or a dict")
565
- else:
566
- normalize_params = normalize_default
567
- normalize_params["normalize"] = normalize
568
- normalize_params["invert"] = invert
569
-
570
- img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels,
571
- channel_axis=channel_axis, z_axis=z_axis,
572
- do_3D=do_3D,
573
- normalize=normalize_params, rescale=rescale,
574
- diameter=diameter,
575
- tile_overlap=tile_overlap, bsize=bsize)
576
-
577
- # turn off special normalization for segmentation
578
- normalize_params = normalize_default
579
-
580
- # change channels for segmentation
581
- if channels is not None:
582
- channels_new = [0, 0] if channels[0] == 0 else [1, 2]
583
- else:
584
- channels_new = None
585
- # change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean)
586
- diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
587
- masks, flows, styles = self.cp.eval(
588
- img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1,
589
- z_axis=0 if not isinstance(img_restore, list) and img_restore.ndim > 3 and img_restore.shape[0] > 0 else None,
590
- normalize=normalize_params, rescale=rescale, diameter=diameter,
591
- tile_overlap=tile_overlap, augment=augment, resample=resample,
592
- invert=invert, flow_threshold=flow_threshold,
593
- cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy,
594
- stitch_threshold=stitch_threshold, min_size=min_size, niter=niter,
595
- interp=interp, bsize=bsize)
596
-
597
- return masks, flows, styles, img_restore
598
-
599
-
600
- class DenoiseModel():
601
- """
602
- DenoiseModel class for denoising images using Cellpose denoising model.
603
-
604
- Args:
605
- gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
606
- pretrained_model (bool or str or Path, optional): Pretrained model to use for denoising.
607
- Can be a string or path. Defaults to False.
608
- nchan (int, optional): Number of channels in the input images, all Cellpose 3 models were trained with nchan=1. Defaults to 1.
609
- model_type (str, optional): Type of pretrained model to use ("denoise_cyto3", "deblur_cyto3", "upsample_cyto3", ...). Defaults to None.
610
- chan2 (bool, optional): Whether to use a separate model for the second channel. Defaults to False.
611
- diam_mean (float, optional): Mean diameter of the objects in the images. Defaults to 30.0.
612
- device (torch.device, optional): Device to use for computation. Defaults to None.
613
-
614
- Attributes:
615
- nchan (int): Number of channels in the input images.
616
- diam_mean (float): Mean diameter of the objects in the images.
617
- net (CPnet): Cellpose network for denoising.
618
- pretrained_model (bool or str or Path): Pretrained model path to use for denoising.
619
- net_chan2 (CPnet or None): Cellpose network for the second channel, if applicable.
620
- net_type (str): Type of the denoising network.
621
-
622
- Methods:
623
- eval(x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
624
- normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1)
625
- Denoise array or list of images using the denoising model.
626
-
627
- _eval(net, x, normalize=True, rescale=None, diameter=None, tile=True,
628
- tile_overlap=0.1)
629
- Run denoising model on a single channel.
630
- """
631
-
632
- def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
633
- chan2=False, diam_mean=30., device=None):
634
- self.nchan = nchan
635
- if pretrained_model and (not isinstance(pretrained_model, str) and
636
- not isinstance(pretrained_model, Path)):
637
- raise ValueError("pretrained_model must be a string or path")
638
-
639
- self.diam_mean = diam_mean
640
- builtin = True
641
- if model_type is not None or (pretrained_model and
642
- not os.path.exists(pretrained_model)):
643
- pretrained_model_string = model_type if model_type is not None else "denoise_cyto3"
644
- if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]):
645
- pretrained_model_string = "denoise_cyto3"
646
- pretrained_model = model_path(pretrained_model_string)
647
- if (pretrained_model and not os.path.exists(pretrained_model)):
648
- denoise_logger.warning("pretrained model has incorrect path")
649
- denoise_logger.info(f">> {pretrained_model_string} << model set to be used")
650
- self.diam_mean = 17. if "nuclei" in pretrained_model_string else 30.
651
- else:
652
- if pretrained_model:
653
- builtin = False
654
- pretrained_model_string = pretrained_model
655
- denoise_logger.info(f">>>> loading model {pretrained_model_string}")
656
-
657
- # assign network device
658
- if device is None:
659
- sdevice, gpu = assign_device(use_torch=True, gpu=gpu)
660
- self.device = device if device is not None else sdevice
661
- if device is not None:
662
- device_gpu = self.device.type == "cuda"
663
- self.gpu = gpu if device is None else device_gpu
664
-
665
- # create network
666
- self.nchan = nchan
667
- self.nclasses = 1
668
- nbase = [32, 64, 128, 256]
669
- self.nchan = nchan
670
- self.nbase = [nchan, *nbase]
671
-
672
- self.net = CPnet(self.nbase, self.nclasses, sz=3,
673
- max_pool=True, diam_mean=diam_mean).to(self.device)
674
-
675
- self.pretrained_model = pretrained_model
676
- self.net_chan2 = None
677
- if self.pretrained_model:
678
- self.net.load_model(self.pretrained_model, device=self.device)
679
- denoise_logger.info(
680
- f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)"
681
- )
682
- if chan2 and builtin:
683
- chan2_path = model_path(
684
- os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei")
685
- print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}")
686
- self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3,
687
- max_pool=True,
688
- diam_mean=17.).to(self.device)
689
- self.net_chan2.load_model(chan2_path, device=self.device)
690
- self.net_type = "cellpose_denoise"
691
-
692
- def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
693
- normalize=True, rescale=None, diameter=None, tile=True, do_3D=False,
694
- tile_overlap=0.1, bsize=224):
695
- """
696
- Restore array or list of images using the image restoration model.
697
-
698
- Args:
699
- x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
700
- batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
701
- (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
702
- channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
703
- First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
704
- Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
705
- For instance, to segment grayscale images, input [0,0]. To segment images with cells
706
- in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
707
- image with cells in green and nuclei in blue, input [[0,0], [2,3]].
708
- Defaults to None.
709
- channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
710
- if None, channels dimension is attempted to be automatically determined. Defaults to None.
711
- z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
712
- if None, z dimension is attempted to be automatically determined. Defaults to None.
713
- normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
714
- can also pass dictionary of parameters (all keys are optional, default values shown):
715
- - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
716
- - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
717
- - "normalize"=True ; run normalization (if False, all following parameters ignored)
718
- - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
719
- - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
720
- - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
721
- Defaults to True.
722
- rescale (float, optional): resize factor for each image, if None, set to 1.0;
723
- (only used if diameter is None). Defaults to None.
724
- diameter (float, optional): diameter for each image,
725
- if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
726
- tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
727
-
728
- Returns:
729
- list: A list of 2D/3D arrays of restored images
730
-
731
- """
732
- if isinstance(x, list) or x.squeeze().ndim == 5:
733
- tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO)
734
- nimg = len(x)
735
- iterator = trange(nimg, file=tqdm_out,
736
- mininterval=30) if nimg > 1 else range(nimg)
737
- imgs = []
738
- for i in iterator:
739
- imgi = self.eval(
740
- x[i], batch_size=batch_size,
741
- channels=channels[i] if channels is not None and
742
- ((len(channels) == len(x) and
743
- (isinstance(channels[i], list) or
744
- isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2))
745
- else channels, channel_axis=channel_axis, z_axis=z_axis,
746
- normalize=normalize,
747
- do_3D=do_3D,
748
- rescale=rescale[i] if isinstance(rescale, list) or
749
- isinstance(rescale, np.ndarray) else rescale,
750
- diameter=diameter[i] if isinstance(diameter, list) or
751
- isinstance(diameter, np.ndarray) else diameter,
752
- tile_overlap=tile_overlap, bsize=bsize)
753
- imgs.append(imgi)
754
- if isinstance(x, np.ndarray):
755
- imgs = np.array(imgs)
756
- return imgs
757
-
758
- else:
759
- # reshape image
760
- x = transforms.convert_image(x, channels, channel_axis=channel_axis,
761
- z_axis=z_axis, do_3D=do_3D, nchan=None)
762
- if x.ndim < 4:
763
- squeeze = True
764
- x = x[np.newaxis, ...]
765
- else:
766
- squeeze = False
767
-
768
- # may need to interpolate image before running upsampling
769
- self.ratio = 1.
770
- if "upsample" in self.pretrained_model:
771
- Ly, Lx = x.shape[-3:-1]
772
- if diameter is not None and 3 <= diameter < self.diam_mean:
773
- self.ratio = self.diam_mean / diameter
774
- denoise_logger.info(
775
- f"upsampling image to {self.diam_mean} pixel diameter ({self.ratio:0.2f} times)"
776
- )
777
- Lyr, Lxr = int(Ly * self.ratio), int(Lx * self.ratio)
778
- x = transforms.resize_image(x, Ly=Lyr, Lx=Lxr)
779
- else:
780
- denoise_logger.warning(
781
- f"not interpolating image before upsampling because diameter is set >= {self.diam_mean}"
782
- )
783
- #raise ValueError(f"diameter is set to {diameter}, needs to be >=3 and < {self.dn.diam_mean}")
784
-
785
- self.batch_size = batch_size
786
-
787
- if diameter is not None and diameter > 0:
788
- rescale = self.diam_mean / diameter
789
- elif rescale is None:
790
- rescale = 1.0
791
-
792
- if np.ptp(x[..., -1]) < 1e-3 or (channels is not None and channels[-1] == 0):
793
- x = x[..., :1]
794
-
795
- for c in range(x.shape[-1]):
796
- rescale0 = rescale * 30. / 17. if c == 1 else rescale
797
- if c == 0 or self.net_chan2 is None:
798
- x[...,
799
- c] = self._eval(self.net, x[..., c:c + 1], batch_size=batch_size,
800
- normalize=normalize, rescale=rescale0,
801
- tile_overlap=tile_overlap, bsize=bsize)[...,0]
802
- else:
803
- x[...,
804
- c] = self._eval(self.net_chan2, x[...,
805
- c:c + 1], batch_size=batch_size,
806
- normalize=normalize, rescale=rescale0,
807
- tile_overlap=tile_overlap, bsize=bsize)[...,0]
808
- x = x[0] if squeeze else x
809
- return x
810
-
811
- def _eval(self, net, x, batch_size=8, normalize=True, rescale=None,
812
- tile_overlap=0.1, bsize=224):
813
- """
814
- Run image restoration model on a single channel.
815
-
816
- Args:
817
- x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
818
- batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
819
- (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
820
- normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
821
- can also pass dictionary of parameters (all keys are optional, default values shown):
822
- - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
823
- - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
824
- - "normalize"=True ; run normalization (if False, all following parameters ignored)
825
- - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
826
- - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
827
- - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
828
- Defaults to True.
829
- rescale (float, optional): resize factor for each image, if None, set to 1.0;
830
- (only used if diameter is None). Defaults to None.
831
- tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
832
-
833
- Returns:
834
- list: A list of 2D/3D arrays of restored images
835
-
836
- """
837
- if isinstance(normalize, dict):
838
- normalize_params = {**normalize_default, **normalize}
839
- elif not isinstance(normalize, bool):
840
- raise ValueError("normalize parameter must be a bool or a dict")
841
- else:
842
- normalize_params = normalize_default
843
- normalize_params["normalize"] = normalize
844
-
845
- tic = time.time()
846
- shape = x.shape
847
- nimg = shape[0]
848
-
849
- do_normalization = True if normalize_params["normalize"] else False
850
-
851
- img = np.asarray(x)
852
- if do_normalization:
853
- img = transforms.normalize_img(img, **normalize_params)
854
- if rescale != 1.0:
855
- img = transforms.resize_image(img, rsz=rescale)
856
- yf, style = run_net(self.net, img, bsize=bsize,
857
- tile_overlap=tile_overlap)
858
- yf = transforms.resize_image(yf, shape[1], shape[2])
859
- imgs = yf
860
- del yf, style
861
-
862
- # imgs = np.zeros((*x.shape[:-1], 1), np.float32)
863
- # for i in iterator:
864
- # img = np.asarray(x[i])
865
- # if do_normalization:
866
- # img = transforms.normalize_img(img, **normalize_params)
867
- # if rescale != 1.0:
868
- # img = transforms.resize_image(img, rsz=[rescale, rescale])
869
- # if img.ndim == 2:
870
- # img = img[:, :, np.newaxis]
871
- # yf, style = run_net(net, img, batch_size=batch_size, augment=False,
872
- # tile=tile, tile_overlap=tile_overlap, bsize=bsize)
873
- # img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2])
874
-
875
- # if img.ndim == 2:
876
- # img = img[:, :, np.newaxis]
877
- # imgs[i] = img
878
- # del yf, style
879
- net_time = time.time() - tic
880
- if nimg > 1:
881
- denoise_logger.info("imgs denoised in %2.2fs" % (net_time))
882
-
883
- return imgs
884
-
885
-
886
- def train(net, train_data=None, train_labels=None, train_files=None, test_data=None,
887
- test_labels=None, test_files=None, train_probs=None, test_probs=None,
888
- lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None,
889
- save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0,
890
- iso=True, uniform_blur=False, downsample=0., ds_max=7,
891
- learning_rate=0.005, n_epochs=500,
892
- weight_decay=0.00001, batch_size=8, nimg_per_epoch=None,
893
- nimg_test_per_epoch=None, model_name=None):
894
-
895
- # net properties
896
- device = net.device
897
- nchan = net.nchan
898
- diam_mean = net.diam_mean.item()
899
-
900
- args = np.array([poisson, beta, blur, gblur, downsample])
901
- if args.ndim == 1:
902
- args = args[:, np.newaxis]
903
- poisson, beta, blur, gblur, downsample = args
904
- nnoise = len(poisson)
905
-
906
- d = datetime.datetime.now()
907
- if save_path is not None:
908
- if model_name is None:
909
- filename = ""
910
- lstrs = ["per", "seg", "rec"]
911
- for k, (l, s) in enumerate(zip(lam, lstrs)):
912
- filename += f"{s}_{l:.2f}_"
913
- if not iso:
914
- filename += "aniso_"
915
- if poisson.sum() > 0:
916
- filename += "poisson_"
917
- if blur.sum() > 0:
918
- filename += "blur_"
919
- if downsample.sum() > 0:
920
- filename += "downsample_"
921
- filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f")
922
- filename = os.path.join(save_path, filename)
923
- else:
924
- filename = os.path.join(save_path, model_name)
925
- print(filename)
926
- for i in range(len(poisson)):
927
- denoise_logger.info(
928
- f"poisson: {poisson[i]: 0.2f}, beta: {beta[i]: 0.2f}, blur: {blur[i]: 0.2f}, gblur: {gblur[i]: 0.2f}, downsample: {downsample[i]: 0.2f}"
929
- )
930
- net1 = one_chan_cellpose(device=device, pretrained_model=seg_model_type)
931
-
932
- learning_rate_const = learning_rate
933
- LR = np.linspace(0, learning_rate_const, 10)
934
- LR = np.append(LR, learning_rate_const * np.ones(n_epochs - 100))
935
- for i in range(10):
936
- LR = np.append(LR, LR[-1] / 2 * np.ones(10))
937
- learning_rate = LR
938
-
939
- batch_size = 8
940
- optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate[0],
941
- weight_decay=weight_decay)
942
- if train_data is not None:
943
- nimg = len(train_data)
944
- diam_train = np.array(
945
- [utils.diameters(train_labels[k])[0] for k in trange(len(train_labels))])
946
- diam_train[diam_train < 5] = 5.
947
- if test_data is not None:
948
- diam_test = np.array(
949
- [utils.diameters(test_labels[k])[0] for k in trange(len(test_labels))])
950
- diam_test[diam_test < 5] = 5.
951
- nimg_test = len(test_data)
952
- else:
953
- nimg = len(train_files)
954
- denoise_logger.info(">>> using files instead of loading dataset")
955
- train_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in train_files]
956
- denoise_logger.info(">>> computing diameters")
957
- diam_train = np.array([
958
- utils.diameters(io.imread(train_labels_files[k])[0])[0]
959
- for k in trange(len(train_labels_files))
960
- ])
961
- diam_train[diam_train < 5] = 5.
962
- if test_files is not None:
963
- nimg_test = len(test_files)
964
- test_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in test_files]
965
- diam_test = np.array([
966
- utils.diameters(io.imread(test_labels_files[k])[0])[0]
967
- for k in trange(len(test_labels_files))
968
- ])
969
- diam_test[diam_test < 5] = 5.
970
- train_probs = 1. / nimg * np.ones(nimg,
971
- "float64") if train_probs is None else train_probs
972
- if test_files is not None or test_data is not None:
973
- test_probs = 1. / nimg_test * np.ones(
974
- nimg_test, "float64") if test_probs is None else test_probs
975
-
976
- tic = time.time()
977
-
978
- nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
979
- if test_files is not None or test_data is not None:
980
- nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch
981
-
982
- nbatch = 0
983
- train_losses, test_losses = [], []
984
- for iepoch in range(n_epochs):
985
- np.random.seed(iepoch)
986
- rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
987
- p=train_probs)
988
- torch.manual_seed(iepoch)
989
- np.random.seed(iepoch)
990
- for param_group in optimizer.param_groups:
991
- param_group["lr"] = learning_rate[iepoch]
992
- lavg, lavg_per, nsum = 0, 0, 0
993
- for ibatch in range(0, nimg_per_epoch, batch_size * nnoise):
994
- inds = rperm[ibatch : ibatch + batch_size * nnoise]
995
- if train_data is None:
996
- imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds]
997
- lbls = [io.imread(train_labels_files[i])[1:] for i in inds]
998
- else:
999
- imgs = [train_data[i][:nchan] for i in inds]
1000
- lbls = [train_labels[i][1:] for i in inds]
1001
- #inoise = nbatch % nnoise
1002
- rnoise = np.random.permutation(nnoise)
1003
- for i, inoise in enumerate(rnoise):
1004
- if i * batch_size < len(imgs):
1005
- imgi, lbli, scale = random_rotate_and_resize_noise(
1006
- imgs[i * batch_size : (i + 1) * batch_size],
1007
- lbls[i * batch_size : (i + 1) * batch_size],
1008
- diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(),
1009
- poisson=poisson[inoise],
1010
- beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso,
1011
- downsample=downsample[inoise], uniform_blur=uniform_blur,
1012
- diam_mean=diam_mean, ds_max=ds_max,
1013
- device=device)
1014
- if i == 0:
1015
- img = imgi
1016
- lbl = lbli
1017
- else:
1018
- img = torch.cat((img, imgi), axis=0)
1019
- lbl = torch.cat((lbl, lbli), axis=0)
1020
-
1021
- if nnoise > 0:
1022
- iperm = np.random.permutation(img.shape[0])
1023
- img, lbl = img[iperm], lbl[iperm]
1024
-
1025
- for i in range(nnoise):
1026
- optimizer.zero_grad()
1027
- imgi = img[i * batch_size: (i + 1) * batch_size]
1028
- lbli = lbl[i * batch_size: (i + 1) * batch_size]
1029
- if imgi.shape[0] > 0:
1030
- loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1,
1031
- img=imgi[:, nchan:], lbl=lbli, lam=lam)
1032
- loss.backward()
1033
- optimizer.step()
1034
- lavg += loss.item() * imgi.shape[0]
1035
- lavg_per += loss_per.item() * imgi.shape[0]
1036
-
1037
- nsum += len(img)
1038
- nbatch += 1
1039
-
1040
- if iepoch % 5 == 0 or iepoch < 10:
1041
- lavg = lavg / nsum
1042
- lavg_per = lavg_per / nsum
1043
- if test_data is not None or test_files is not None:
1044
- lavgt, nsum = 0., 0
1045
- np.random.seed(42)
1046
- rperm = np.random.choice(np.arange(0, nimg_test),
1047
- size=(nimg_test_per_epoch,), p=test_probs)
1048
- inoise = iepoch % nnoise
1049
- torch.manual_seed(inoise)
1050
- for ibatch in range(0, nimg_test_per_epoch, batch_size):
1051
- inds = rperm[ibatch:ibatch + batch_size]
1052
- if test_data is None:
1053
- imgs = [
1054
- np.maximum(0,
1055
- io.imread(test_files[i])[:nchan]) for i in inds
1056
- ]
1057
- lbls = [io.imread(test_labels_files[i])[1:] for i in inds]
1058
- else:
1059
- imgs = [test_data[i][:nchan] for i in inds]
1060
- lbls = [test_labels[i][1:] for i in inds]
1061
- img, lbl, scale = random_rotate_and_resize_noise(
1062
- imgs, lbls, diam_test[inds].copy(), poisson=poisson[inoise],
1063
- beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise],
1064
- iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur,
1065
- diam_mean=diam_mean, ds_max=ds_max, device=device)
1066
- loss, loss_per = test_loss(net, img[:, :nchan], net1=net1,
1067
- img=img[:, nchan:], lbl=lbl, lam=lam)
1068
-
1069
- lavgt += loss.item() * img.shape[0]
1070
- nsum += len(img)
1071
- lavgt = lavgt / nsum
1072
- denoise_logger.info(
1073
- "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f"
1074
- % (iepoch, time.time() - tic, lavg, lavg_per, lavgt,
1075
- learning_rate[iepoch]))
1076
- test_losses.append(lavgt)
1077
- else:
1078
- denoise_logger.info(
1079
- "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" %
1080
- (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch]))
1081
- train_losses.append(lavg)
1082
-
1083
- if save_path is not None:
1084
- if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
1085
- if save_each: #separate files as model progresses
1086
- filename0 = str(filename) + f"_epoch_{iepoch:%04d}"
1087
- else:
1088
- filename0 = filename
1089
- denoise_logger.info(f"saving network parameters to {filename0}")
1090
- net.save_model(filename0)
1091
- else:
1092
- filename = save_path
1093
-
1094
- return filename, train_losses, test_losses
1095
-
1096
-
1097
- if __name__ == "__main__":
1098
- import argparse
1099
- parser = argparse.ArgumentParser(description="cellpose parameters")
1100
-
1101
- input_img_args = parser.add_argument_group("input image arguments")
1102
- input_img_args.add_argument("--dir", default=[], type=str,
1103
- help="folder containing data to run or train on.")
1104
- input_img_args.add_argument("--img_filter", default=[], type=str,
1105
- help="end string for images to run on")
1106
-
1107
- model_args = parser.add_argument_group("model arguments")
1108
- model_args.add_argument("--pretrained_model", default=[], type=str,
1109
- help="pretrained denoising model")
1110
-
1111
- training_args = parser.add_argument_group("training arguments")
1112
- training_args.add_argument("--test_dir", default=[], type=str,
1113
- help="folder containing test data (optional)")
1114
- training_args.add_argument("--file_list", default=[], type=str,
1115
- help="npy file containing list of train and test files")
1116
- training_args.add_argument("--seg_model_type", default="cyto2", type=str,
1117
- help="model to use for seg training loss")
1118
- training_args.add_argument(
1119
- "--noise_type", default=[], type=str,
1120
- help="noise type to use (if input, then other noise params are ignored)")
1121
- training_args.add_argument("--poisson", default=0.8, type=float,
1122
- help="fraction of images to add poisson noise to")
1123
- training_args.add_argument("--beta", default=0.7, type=float,
1124
- help="scale of poisson noise")
1125
- training_args.add_argument("--blur", default=0., type=float,
1126
- help="fraction of images to blur")
1127
- training_args.add_argument("--gblur", default=1.0, type=float,
1128
- help="scale of gaussian blurring stddev")
1129
- training_args.add_argument("--downsample", default=0., type=float,
1130
- help="fraction of images to downsample")
1131
- training_args.add_argument("--ds_max", default=7, type=int,
1132
- help="max downsampling factor")
1133
- training_args.add_argument("--lam_per", default=1.0, type=float,
1134
- help="weighting of perceptual loss")
1135
- training_args.add_argument("--lam_seg", default=1.5, type=float,
1136
- help="weighting of segmentation loss")
1137
- training_args.add_argument("--lam_rec", default=0., type=float,
1138
- help="weighting of reconstruction loss")
1139
- training_args.add_argument(
1140
- "--diam_mean", default=30., type=float, help=
1141
- "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0"
1142
- )
1143
- training_args.add_argument("--learning_rate", default=0.001, type=float,
1144
- help="learning rate. Default: %(default)s")
1145
- training_args.add_argument("--n_epochs", default=2000, type=int,
1146
- help="number of epochs. Default: %(default)s")
1147
- training_args.add_argument(
1148
- "--save_each", default=False, action="store_true",
1149
- help="save each epoch as separate model")
1150
- training_args.add_argument(
1151
- "--nimg_per_epoch", default=0, type=int,
1152
- help="number of images per epoch. Default is length of training images")
1153
- training_args.add_argument(
1154
- "--nimg_test_per_epoch", default=0, type=int,
1155
- help="number of test images per epoch. Default is length of testing images")
1156
-
1157
- io.logger_setup()
1158
-
1159
- args = parser.parse_args()
1160
- lams = [args.lam_per, args.lam_seg, args.lam_rec]
1161
- print("lam", lams)
1162
-
1163
- if len(args.noise_type) > 0:
1164
- noise_type = args.noise_type
1165
- uniform_blur = False
1166
- iso = True
1167
- if noise_type == "poisson":
1168
- poisson = 0.8
1169
- blur = 0.
1170
- downsample = 0.
1171
- beta = 0.7
1172
- gblur = 1.0
1173
- elif noise_type == "blur_expr":
1174
- poisson = 0.8
1175
- blur = 0.8
1176
- downsample = 0.
1177
- beta = 0.1
1178
- gblur = 0.5
1179
- elif noise_type == "blur":
1180
- poisson = 0.8
1181
- blur = 0.8
1182
- downsample = 0.
1183
- beta = 0.1
1184
- gblur = 10.0
1185
- uniform_blur = True
1186
- elif noise_type == "downsample_expr":
1187
- poisson = 0.8
1188
- blur = 0.8
1189
- downsample = 0.8
1190
- beta = 0.03
1191
- gblur = 1.0
1192
- elif noise_type == "downsample":
1193
- poisson = 0.8
1194
- blur = 0.8
1195
- downsample = 0.8
1196
- beta = 0.03
1197
- gblur = 5.0
1198
- uniform_blur = True
1199
- elif noise_type == "all":
1200
- poisson = [0.8, 0.8, 0.8]
1201
- blur = [0., 0.8, 0.8]
1202
- downsample = [0., 0., 0.8]
1203
- beta = [0.7, 0.1, 0.03]
1204
- gblur = [0., 10.0, 5.0]
1205
- uniform_blur = True
1206
- elif noise_type == "aniso":
1207
- poisson = 0.8
1208
- blur = 0.8
1209
- downsample = 0.8
1210
- beta = 0.1
1211
- gblur = args.ds_max * 1.5
1212
- iso = False
1213
- else:
1214
- raise ValueError(f"{noise_type} noise_type is not supported")
1215
- else:
1216
- poisson, beta = args.poisson, args.beta
1217
- blur, gblur = args.blur, args.gblur
1218
- downsample = args.downsample
1219
-
1220
- pretrained_model = None if len(
1221
- args.pretrained_model) == 0 else args.pretrained_model
1222
- model = DenoiseModel(gpu=True, nchan=1, diam_mean=args.diam_mean,
1223
- pretrained_model=pretrained_model)
1224
-
1225
- train_data, labels, train_files, train_probs = None, None, None, None
1226
- test_data, test_labels, test_files, test_probs = None, None, None, None
1227
- if len(args.file_list) == 0:
1228
- output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0)
1229
- images, labels, image_names, test_images, test_labels, image_names_test = output
1230
- train_data = []
1231
- for i in range(len(images)):
1232
- img = images[i].astype("float32")
1233
- if img.ndim > 2:
1234
- img = img[0]
1235
- train_data.append(
1236
- np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
1237
- if len(args.test_dir) > 0:
1238
- test_data = []
1239
- for i in range(len(test_images)):
1240
- img = test_images[i].astype("float32")
1241
- if img.ndim > 2:
1242
- img = img[0]
1243
- test_data.append(
1244
- np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
1245
- save_path = os.path.join(args.dir, "../models/")
1246
- else:
1247
- root = args.dir
1248
- denoise_logger.info(
1249
- ">>> using file_list (assumes images are normalized and have flows!)")
1250
- dat = np.load(args.file_list, allow_pickle=True).item()
1251
- train_files = dat["train_files"]
1252
- test_files = dat["test_files"]
1253
- train_probs = dat["train_probs"] if "train_probs" in dat else None
1254
- test_probs = dat["test_probs"] if "test_probs" in dat else None
1255
- if str(train_files[0])[:len(str(root))] != str(root):
1256
- for i in range(len(train_files)):
1257
- new_path = root / Path(*train_files[i].parts[-3:])
1258
- if i == 0:
1259
- print(f"changing path from {train_files[i]} to {new_path}")
1260
- train_files[i] = new_path
1261
-
1262
- for i in range(len(test_files)):
1263
- new_path = root / Path(*test_files[i].parts[-3:])
1264
- test_files[i] = new_path
1265
- save_path = os.path.join(args.dir, "models/")
1266
-
1267
- os.makedirs(save_path, exist_ok=True)
1268
-
1269
- nimg_per_epoch = None if args.nimg_per_epoch == 0 else args.nimg_per_epoch
1270
- nimg_test_per_epoch = None if args.nimg_test_per_epoch == 0 else args.nimg_test_per_epoch
1271
-
1272
- model_path = train(
1273
- model.net, train_data=train_data, train_labels=labels, train_files=train_files,
1274
- test_data=test_data, test_labels=test_labels, test_files=test_files,
1275
- train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta,
1276
- blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max,
1277
- iso=iso, uniform_blur=uniform_blur, n_epochs=args.n_epochs,
1278
- learning_rate=args.learning_rate,
1279
- lam=lams,
1280
- seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch,
1281
- nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path)
1282
-
1283
-
1284
- def seg_train_noisy(model, train_data, train_labels, test_data=None, test_labels=None,
1285
- poisson=0.8, blur=0.0, downsample=0.0, save_path=None,
1286
- save_every=100, save_each=False, learning_rate=0.2, n_epochs=500,
1287
- momentum=0.9, weight_decay=0.00001, SGD=True, batch_size=8,
1288
- nimg_per_epoch=None, diameter=None, rescale=True, z_masking=False,
1289
- model_name=None):
1290
- """ train function uses loss function model.loss_fn in models.py
1291
-
1292
- (data should already be normalized)
1293
-
1294
- """
1295
-
1296
- d = datetime.datetime.now()
1297
-
1298
- model.n_epochs = n_epochs
1299
- if isinstance(learning_rate, (list, np.ndarray)):
1300
- if isinstance(learning_rate, np.ndarray) and learning_rate.ndim > 1:
1301
- raise ValueError("learning_rate.ndim must equal 1")
1302
- elif len(learning_rate) != n_epochs:
1303
- raise ValueError(
1304
- "if learning_rate given as list or np.ndarray it must have length n_epochs"
1305
- )
1306
- model.learning_rate = learning_rate
1307
- model.learning_rate_const = mode(learning_rate)[0][0]
1308
- else:
1309
- model.learning_rate_const = learning_rate
1310
- # set learning rate schedule
1311
- if SGD:
1312
- LR = np.linspace(0, model.learning_rate_const, 10)
1313
- if model.n_epochs > 250:
1314
- LR = np.append(
1315
- LR, model.learning_rate_const * np.ones(model.n_epochs - 100))
1316
- for i in range(10):
1317
- LR = np.append(LR, LR[-1] / 2 * np.ones(10))
1318
- else:
1319
- LR = np.append(
1320
- LR,
1321
- model.learning_rate_const * np.ones(max(0, model.n_epochs - 10)))
1322
- else:
1323
- LR = model.learning_rate_const * np.ones(model.n_epochs)
1324
- model.learning_rate = LR
1325
-
1326
- model.batch_size = batch_size
1327
- model._set_optimizer(model.learning_rate[0], momentum, weight_decay, SGD)
1328
- model._set_criterion()
1329
-
1330
- nimg = len(train_data)
1331
-
1332
- # compute average cell diameter
1333
- if diameter is None:
1334
- diam_train = np.array(
1335
- [utils.diameters(train_labels[k][0])[0] for k in range(len(train_labels))])
1336
- diam_train_mean = diam_train[diam_train > 0].mean()
1337
- model.diam_labels = diam_train_mean
1338
- if rescale:
1339
- diam_train[diam_train < 5] = 5.
1340
- if test_data is not None:
1341
- diam_test = np.array([
1342
- utils.diameters(test_labels[k][0])[0]
1343
- for k in range(len(test_labels))
1344
- ])
1345
- diam_test[diam_test < 5] = 5.
1346
- denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
1347
- elif rescale:
1348
- diam_train_mean = diameter
1349
- model.diam_labels = diameter
1350
- denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
1351
- diam_train = diameter * np.ones(len(train_labels), "float32")
1352
- if test_data is not None:
1353
- diam_test = diameter * np.ones(len(test_labels), "float32")
1354
-
1355
- denoise_logger.info(
1356
- f">>>> mean of training label mask diameters (saved to model) {diam_train_mean:.3f}"
1357
- )
1358
- model.net.diam_labels.data = torch.ones(1, device=model.device) * diam_train_mean
1359
-
1360
- nchan = train_data[0].shape[0]
1361
- denoise_logger.info(">>>> training network with %d channel input <<<<" % nchan)
1362
- denoise_logger.info(">>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f" %
1363
- (model.learning_rate_const, model.batch_size, weight_decay))
1364
-
1365
- if test_data is not None:
1366
- denoise_logger.info(f">>>> ntrain = {nimg}, ntest = {len(test_data)}")
1367
- else:
1368
- denoise_logger.info(f">>>> ntrain = {nimg}")
1369
-
1370
- tic = time.time()
1371
-
1372
- lavg, nsum = 0, 0
1373
-
1374
- if save_path is not None:
1375
- _, file_label = os.path.split(save_path)
1376
- file_path = os.path.join(save_path, "models/")
1377
-
1378
- if not os.path.exists(file_path):
1379
- os.makedirs(file_path)
1380
- else:
1381
- denoise_logger.warning("WARNING: no save_path given, model not saving")
1382
-
1383
- ksave = 0
1384
-
1385
- # get indices for each epoch for training
1386
- np.random.seed(0)
1387
- inds_all = np.zeros((0,), "int32")
1388
- if nimg_per_epoch is None or nimg > nimg_per_epoch:
1389
- nimg_per_epoch = nimg
1390
- denoise_logger.info(f">>>> nimg_per_epoch = {nimg_per_epoch}")
1391
- while len(inds_all) < n_epochs * nimg_per_epoch:
1392
- rperm = np.random.permutation(nimg)
1393
- inds_all = np.hstack((inds_all, rperm))
1394
-
1395
- for iepoch in range(model.n_epochs):
1396
- if SGD:
1397
- model._set_learning_rate(model.learning_rate[iepoch])
1398
- np.random.seed(iepoch)
1399
- rperm = inds_all[iepoch * nimg_per_epoch:(iepoch + 1) * nimg_per_epoch]
1400
- for ibatch in range(0, nimg_per_epoch, batch_size):
1401
- inds = rperm[ibatch:ibatch + batch_size]
1402
- imgi, lbl, scale = random_rotate_and_resize_noise(
1403
- [train_data[i] for i in inds], [train_labels[i][1:] for i in inds],
1404
- poisson=poisson, blur=blur, downsample=downsample,
1405
- diams=diam_train[inds], diam_mean=model.diam_mean)
1406
- imgi = imgi[:, :1] # keep noisy only
1407
- if z_masking:
1408
- nc = imgi.shape[1]
1409
- nb = imgi.shape[0]
1410
- ncmin = (np.random.rand(nb) > 0.25) * (np.random.randint(
1411
- nc // 2 - 1, size=nb))
1412
- ncmax = nc - (np.random.rand(nb) > 0.25) * (np.random.randint(
1413
- nc // 2 - 1, size=nb))
1414
- for b in range(nb):
1415
- imgi[b, :ncmin[b]] = 0
1416
- imgi[b, ncmax[b]:] = 0
1417
-
1418
- train_loss = model._train_step(imgi, lbl)
1419
- lavg += train_loss
1420
- nsum += len(imgi)
1421
-
1422
- if iepoch % 10 == 0 or iepoch == 5:
1423
- lavg = lavg / nsum
1424
- if test_data is not None:
1425
- lavgt, nsum = 0., 0
1426
- np.random.seed(42)
1427
- rperm = np.arange(0, len(test_data), 1, int)
1428
- for ibatch in range(0, len(test_data), batch_size):
1429
- inds = rperm[ibatch:ibatch + batch_size]
1430
- imgi, lbl, scale = random_rotate_and_resize_noise(
1431
- [test_data[i] for i in inds],
1432
- [test_labels[i][1:] for i in inds], poisson=poisson, blur=blur,
1433
- downsample=downsample, diams=diam_test[inds],
1434
- diam_mean=model.diam_mean)
1435
- imgi = imgi[:, :1] # keep noisy only
1436
- test_loss = model._test_eval(imgi, lbl)
1437
- lavgt += test_loss
1438
- nsum += len(imgi)
1439
-
1440
- denoise_logger.info(
1441
- "Epoch %d, Time %4.1fs, Loss %2.4f, Loss Test %2.4f, LR %2.4f" %
1442
- (iepoch, time.time() - tic, lavg, lavgt / nsum,
1443
- model.learning_rate[iepoch]))
1444
- else:
1445
- denoise_logger.info(
1446
- "Epoch %d, Time %4.1fs, Loss %2.4f, LR %2.4f" %
1447
- (iepoch, time.time() - tic, lavg, model.learning_rate[iepoch]))
1448
-
1449
- lavg, nsum = 0, 0
1450
-
1451
- if save_path is not None:
1452
- if iepoch == model.n_epochs - 1 or iepoch % save_every == 1:
1453
- # save model at the end
1454
- if save_each: #separate files as model progresses
1455
- if model_name is None:
1456
- filename = "{}_{}_{}_{}".format(
1457
- model.net_type, file_label,
1458
- d.strftime("%Y_%m_%d_%H_%M_%S.%f"), "epoch_" + str(iepoch))
1459
- else:
1460
- filename = "{}_{}".format(model_name, "epoch_" + str(iepoch))
1461
- else:
1462
- if model_name is None:
1463
- filename = "{}_{}_{}".format(model.net_type, file_label,
1464
- d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
1465
- else:
1466
- filename = model_name
1467
- filename = os.path.join(file_path, filename)
1468
- ksave += 1
1469
- denoise_logger.info(f"saving network parameters to {filename}")
1470
- model.net.save_model(filename)
1471
- else:
1472
- filename = save_path
1473
-
1474
- return filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/export.py DELETED
@@ -1,405 +0,0 @@
1
- """Auxiliary module for bioimageio format export
2
-
3
- Example usage:
4
-
5
- ```bash
6
- #!/bin/bash
7
-
8
- # Define default paths and parameters
9
- DEFAULT_CHANNELS="1 0"
10
- DEFAULT_PATH_PRETRAINED_MODEL="/home/qinyu/models/cp/cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995"
11
- DEFAULT_PATH_README="/home/qinyu/models/cp/README.md"
12
- DEFAULT_LIST_PATH_COVER_IMAGES="/home/qinyu/images/cp/cellpose_raw_and_segmentation.jpg /home/qinyu/images/cp/cellpose_raw_and_probability.jpg /home/qinyu/images/cp/cellpose_raw.jpg"
13
- DEFAULT_MODEL_ID="philosophical-panda"
14
- DEFAULT_MODEL_ICON="🐼"
15
- DEFAULT_MODEL_VERSION="0.1.0"
16
- DEFAULT_MODEL_NAME="My Cool Cellpose"
17
- DEFAULT_MODEL_DOCUMENTATION="A cool Cellpose model trained for my cool dataset."
18
- DEFAULT_MODEL_AUTHORS='[{"name": "Qin Yu", "affiliation": "EMBL", "github_user": "qin-yu", "orcid": "0000-0002-4652-0795"}]'
19
- DEFAULT_MODEL_CITE='[{"text": "For more details of the model itself, see the manuscript", "doi": "10.1242/dev.202800", "url": null}]'
20
- DEFAULT_MODEL_TAGS="cellpose 3d 2d"
21
- DEFAULT_MODEL_LICENSE="MIT"
22
- DEFAULT_MODEL_REPO="https://github.com/kreshuklab/go-nuclear"
23
-
24
- # Run the Python script with default parameters
25
- python export.py \
26
- --channels $DEFAULT_CHANNELS \
27
- --path_pretrained_model "$DEFAULT_PATH_PRETRAINED_MODEL" \
28
- --path_readme "$DEFAULT_PATH_README" \
29
- --list_path_cover_images $DEFAULT_LIST_PATH_COVER_IMAGES \
30
- --model_version "$DEFAULT_MODEL_VERSION" \
31
- --model_name "$DEFAULT_MODEL_NAME" \
32
- --model_documentation "$DEFAULT_MODEL_DOCUMENTATION" \
33
- --model_authors "$DEFAULT_MODEL_AUTHORS" \
34
- --model_cite "$DEFAULT_MODEL_CITE" \
35
- --model_tags $DEFAULT_MODEL_TAGS \
36
- --model_license "$DEFAULT_MODEL_LICENSE" \
37
- --model_repo "$DEFAULT_MODEL_REPO"
38
- ```
39
- """
40
-
41
- import os
42
- import sys
43
- import json
44
- import argparse
45
- from pathlib import Path
46
- from urllib.parse import urlparse
47
-
48
- import torch
49
- import numpy as np
50
-
51
- from cellpose.io import imread
52
- from cellpose.utils import download_url_to_file
53
- from cellpose.transforms import pad_image_ND, normalize_img, convert_image
54
- from cellpose.vit_sam import CPnetBioImageIO
55
-
56
- from bioimageio.spec.model.v0_5 import (
57
- ArchitectureFromFileDescr,
58
- Author,
59
- AxisId,
60
- ChannelAxis,
61
- CiteEntry,
62
- Doi,
63
- FileDescr,
64
- Identifier,
65
- InputTensorDescr,
66
- IntervalOrRatioDataDescr,
67
- LicenseId,
68
- ModelDescr,
69
- ModelId,
70
- OrcidId,
71
- OutputTensorDescr,
72
- ParameterizedSize,
73
- PytorchStateDictWeightsDescr,
74
- SizeReference,
75
- SpaceInputAxis,
76
- SpaceOutputAxis,
77
- TensorId,
78
- TorchscriptWeightsDescr,
79
- Version,
80
- WeightsDescr,
81
- )
82
- # Define ARBITRARY_SIZE if it is not available in the module
83
- try:
84
- from bioimageio.spec.model.v0_5 import ARBITRARY_SIZE
85
- except ImportError:
86
- ARBITRARY_SIZE = ParameterizedSize(min=1, step=1)
87
-
88
- from bioimageio.spec.common import HttpUrl
89
- from bioimageio.spec import save_bioimageio_package
90
- from bioimageio.core import test_model
91
-
92
- DEFAULT_CHANNELS = [2, 1]
93
- DEFAULT_NORMALIZE_PARAMS = {
94
- "axis": -1,
95
- "lowhigh": None,
96
- "percentile": None,
97
- "normalize": True,
98
- "norm3D": False,
99
- "sharpen_radius": 0,
100
- "smooth_radius": 0,
101
- "tile_norm_blocksize": 0,
102
- "tile_norm_smooth3D": 1,
103
- "invert": False,
104
- }
105
- IMAGE_URL = "http://www.cellpose.org/static/data/rgb_3D.tif"
106
-
107
-
108
- def download_and_normalize_image(path_dir_temp, channels=DEFAULT_CHANNELS):
109
- """
110
- Download and normalize image.
111
- """
112
- filename = os.path.basename(urlparse(IMAGE_URL).path)
113
- path_image = path_dir_temp / filename
114
- if not path_image.exists():
115
- sys.stderr.write(f'Downloading: "{IMAGE_URL}" to {path_image}\n')
116
- download_url_to_file(IMAGE_URL, path_image)
117
- img = imread(path_image).astype(np.float32)
118
- img = convert_image(img, channels, channel_axis=1, z_axis=0, do_3D=False, nchan=2)
119
- img = normalize_img(img, **DEFAULT_NORMALIZE_PARAMS)
120
- img = np.transpose(img, (0, 3, 1, 2))
121
- img, _, _ = pad_image_ND(img)
122
- return img
123
-
124
-
125
- def load_bioimageio_cpnet_model(path_model_weight, nchan=2):
126
- cpnet_kwargs = {
127
- "nout": 3,
128
- }
129
- cpnet_biio = CPnetBioImageIO(**cpnet_kwargs)
130
- state_dict_cuda = torch.load(path_model_weight, map_location=torch.device("cpu"), weights_only=True)
131
- cpnet_biio.load_state_dict(state_dict_cuda)
132
- cpnet_biio.eval() # crucial for the prediction results
133
- return cpnet_biio, cpnet_kwargs
134
-
135
-
136
- def descr_gen_input(path_test_input, nchan=2):
137
- input_axes = [
138
- SpaceInputAxis(id=AxisId("z"), size=ARBITRARY_SIZE),
139
- ChannelAxis(channel_names=[Identifier(f"c{i+1}") for i in range(nchan)]),
140
- SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=16, step=16)),
141
- SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=16, step=16)),
142
- ]
143
- data_descr = IntervalOrRatioDataDescr(type="float32")
144
- path_test_input = Path(path_test_input)
145
- descr_input = InputTensorDescr(
146
- id=TensorId("raw"),
147
- axes=input_axes,
148
- test_tensor=FileDescr(source=path_test_input),
149
- data=data_descr,
150
- )
151
- return descr_input
152
-
153
-
154
- def descr_gen_output_flow(path_test_output):
155
- output_axes_output_tensor = [
156
- SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
157
- ChannelAxis(channel_names=[Identifier("flow1"), Identifier("flow2"), Identifier("flow3")]),
158
- SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y"))),
159
- SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x"))),
160
- ]
161
- path_test_output = Path(path_test_output)
162
- descr_output = OutputTensorDescr(
163
- id=TensorId("flow"),
164
- axes=output_axes_output_tensor,
165
- test_tensor=FileDescr(source=path_test_output),
166
- )
167
- return descr_output
168
-
169
-
170
- def descr_gen_output_downsampled(path_dir_temp, nbase=None):
171
- if nbase is None:
172
- nbase = [32, 64, 128, 256]
173
-
174
- output_axes_downsampled_tensors = [
175
- [
176
- SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
177
- ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(base)]),
178
- SpaceOutputAxis(
179
- id=AxisId("y"),
180
- size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")),
181
- scale=2**offset,
182
- ),
183
- SpaceOutputAxis(
184
- id=AxisId("x"),
185
- size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x")),
186
- scale=2**offset,
187
- ),
188
- ]
189
- for offset, base in enumerate(nbase)
190
- ]
191
- path_downsampled_tensors = [
192
- Path(path_dir_temp / f"test_downsampled_{i}.npy") for i in range(len(output_axes_downsampled_tensors))
193
- ]
194
- descr_output_downsampled_tensors = [
195
- OutputTensorDescr(
196
- id=TensorId(f"downsampled_{i}"),
197
- axes=axes,
198
- test_tensor=FileDescr(source=path),
199
- )
200
- for i, (axes, path) in enumerate(zip(output_axes_downsampled_tensors, path_downsampled_tensors))
201
- ]
202
- return descr_output_downsampled_tensors
203
-
204
-
205
- def descr_gen_output_style(path_test_style, nchannel=256):
206
- output_axes_style_tensor = [
207
- SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
208
- ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(nchannel)]),
209
- ]
210
- path_style_tensor = Path(path_test_style)
211
- descr_output_style_tensor = OutputTensorDescr(
212
- id=TensorId("style"),
213
- axes=output_axes_style_tensor,
214
- test_tensor=FileDescr(source=path_style_tensor),
215
- )
216
- return descr_output_style_tensor
217
-
218
-
219
- def descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper=None):
220
- if path_cpnet_wrapper is None:
221
- path_cpnet_wrapper = Path(__file__).parent / "resnet_torch.py"
222
- pytorch_architecture = ArchitectureFromFileDescr(
223
- callable=Identifier("CPnetBioImageIO"),
224
- source=Path(path_cpnet_wrapper),
225
- kwargs=cpnet_kwargs,
226
- )
227
- return pytorch_architecture
228
-
229
-
230
- def descr_gen_documentation(path_doc, markdown_text):
231
- with open(path_doc, "w") as f:
232
- f.write(markdown_text)
233
-
234
-
235
- def package_to_bioimageio(
236
- path_pretrained_model,
237
- path_save_trace,
238
- path_readme,
239
- list_path_cover_images,
240
- descr_input,
241
- descr_output,
242
- descr_output_downsampled_tensors,
243
- descr_output_style_tensor,
244
- pytorch_version,
245
- pytorch_architecture,
246
- model_id,
247
- model_icon,
248
- model_version,
249
- model_name,
250
- model_documentation,
251
- model_authors,
252
- model_cite,
253
- model_tags,
254
- model_license,
255
- model_repo,
256
- ):
257
- """Package model description to BioImage.IO format."""
258
- my_model_descr = ModelDescr(
259
- id=ModelId(model_id) if model_id is not None else None,
260
- id_emoji=model_icon,
261
- version=Version(model_version),
262
- name=model_name,
263
- description=model_documentation,
264
- authors=[
265
- Author(
266
- name=author["name"],
267
- affiliation=author["affiliation"],
268
- github_user=author["github_user"],
269
- orcid=OrcidId(author["orcid"]),
270
- )
271
- for author in model_authors
272
- ],
273
- cite=[CiteEntry(text=cite["text"], doi=Doi(cite["doi"]), url=cite["url"]) for cite in model_cite],
274
- covers=[Path(img) for img in list_path_cover_images],
275
- license=LicenseId(model_license),
276
- tags=model_tags,
277
- documentation=Path(path_readme),
278
- git_repo=HttpUrl(model_repo),
279
- inputs=[descr_input],
280
- outputs=[descr_output, descr_output_style_tensor] + descr_output_downsampled_tensors,
281
- weights=WeightsDescr(
282
- pytorch_state_dict=PytorchStateDictWeightsDescr(
283
- source=Path(path_pretrained_model),
284
- architecture=pytorch_architecture,
285
- pytorch_version=pytorch_version,
286
- ),
287
- torchscript=TorchscriptWeightsDescr(
288
- source=Path(path_save_trace),
289
- pytorch_version=pytorch_version,
290
- parent="pytorch_state_dict", # these weights were converted from the pytorch_state_dict weights.
291
- ),
292
- ),
293
- )
294
-
295
- return my_model_descr
296
-
297
-
298
- def parse_args():
299
- # fmt: off
300
- parser = argparse.ArgumentParser(description="BioImage.IO model packaging for Cellpose")
301
- parser.add_argument("--channels", nargs=2, default=[2, 1], type=int, help="Cyto-only = [2, 0], Cyto + Nuclei = [2, 1], Nuclei-only = [1, 0]")
302
- parser.add_argument("--path_pretrained_model", required=True, type=str, help="Path to pretrained model file, e.g., cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995")
303
- parser.add_argument("--path_readme", required=True, type=str, help="Path to README file")
304
- parser.add_argument("--list_path_cover_images", nargs='+', required=True, type=str, help="List of paths to cover images")
305
- parser.add_argument("--model_id", type=str, help="Model ID, provide if already exists", default=None)
306
- parser.add_argument("--model_icon", type=str, help="Model icon, provide if already exists", default=None)
307
- parser.add_argument("--model_version", required=True, type=str, help="Model version, new model should be 0.1.0")
308
- parser.add_argument("--model_name", required=True, type=str, help="Model name, e.g., My Cool Cellpose")
309
- parser.add_argument("--model_documentation", required=True, type=str, help="Model documentation, e.g., A cool Cellpose model trained for my cool dataset.")
310
- parser.add_argument("--model_authors", required=True, type=str, help="Model authors in JSON format, e.g., '[{\"name\": \"Qin Yu\", \"affiliation\": \"EMBL\", \"github_user\": \"qin-yu\", \"orcid\": \"0000-0002-4652-0795\"}]'")
311
- parser.add_argument("--model_cite", required=True, type=str, help="Model citation in JSON format, e.g., '[{\"text\": \"For more details of the model itself, see the manuscript\", \"doi\": \"10.1242/dev.202800\", \"url\": null}]'")
312
- parser.add_argument("--model_tags", nargs='+', required=True, type=str, help="Model tags, e.g., cellpose 3d 2d")
313
- parser.add_argument("--model_license", required=True, type=str, help="Model license, e.g., MIT")
314
- parser.add_argument("--model_repo", required=True, type=str, help="Model repository URL")
315
- return parser.parse_args()
316
- # fmt: on
317
-
318
-
319
- def main():
320
- args = parse_args()
321
-
322
- # Parse user-provided paths and arguments
323
- channels = args.channels
324
- model_cite = json.loads(args.model_cite)
325
- model_authors = json.loads(args.model_authors)
326
-
327
- path_readme = Path(args.path_readme)
328
- path_pretrained_model = Path(args.path_pretrained_model)
329
- list_path_cover_images = [Path(path_image) for path_image in args.list_path_cover_images]
330
-
331
- # Auto-generated paths
332
- path_cpnet_wrapper = Path(__file__).resolve().parent / "resnet_torch.py"
333
- path_dir_temp = Path(__file__).resolve().parent.parent / "models" / path_pretrained_model.stem
334
- path_dir_temp.mkdir(parents=True, exist_ok=True)
335
-
336
- path_save_trace = path_dir_temp / "cp_traced.pt"
337
- path_test_input = path_dir_temp / "test_input.npy"
338
- path_test_output = path_dir_temp / "test_output.npy"
339
- path_test_style = path_dir_temp / "test_style.npy"
340
- path_bioimageio_package = path_dir_temp / "cellpose_model.zip"
341
-
342
- # Download test input image
343
- img_np = download_and_normalize_image(path_dir_temp, channels=channels)
344
- np.save(path_test_input, img_np)
345
- img = torch.tensor(img_np).float()
346
-
347
- # Load model
348
- cpnet_biio, cpnet_kwargs = load_bioimageio_cpnet_model(path_pretrained_model)
349
-
350
- # Test model and save output
351
- tuple_output_tensor = cpnet_biio(img)
352
- np.save(path_test_output, tuple_output_tensor[0].detach().numpy())
353
- np.save(path_test_style, tuple_output_tensor[1].detach().numpy())
354
- for i, t in enumerate(tuple_output_tensor[2:]):
355
- np.save(path_dir_temp / f"test_downsampled_{i}.npy", t.detach().numpy())
356
-
357
- # Save traced model
358
- model_traced = torch.jit.trace(cpnet_biio, img)
359
- model_traced.save(path_save_trace)
360
-
361
- # Generate model description
362
- descr_input = descr_gen_input(path_test_input)
363
- descr_output = descr_gen_output_flow(path_test_output)
364
- descr_output_downsampled_tensors = descr_gen_output_downsampled(path_dir_temp, nbase=cpnet_biio.nbase[1:])
365
- descr_output_style_tensor = descr_gen_output_style(path_test_style, cpnet_biio.nbase[-1])
366
- pytorch_version = Version(torch.__version__)
367
- pytorch_architecture = descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper)
368
-
369
- # Package model
370
- my_model_descr = package_to_bioimageio(
371
- path_pretrained_model,
372
- path_save_trace,
373
- path_readme,
374
- list_path_cover_images,
375
- descr_input,
376
- descr_output,
377
- descr_output_downsampled_tensors,
378
- descr_output_style_tensor,
379
- pytorch_version,
380
- pytorch_architecture,
381
- args.model_id,
382
- args.model_icon,
383
- args.model_version,
384
- args.model_name,
385
- args.model_documentation,
386
- model_authors,
387
- model_cite,
388
- args.model_tags,
389
- args.model_license,
390
- args.model_repo,
391
- )
392
-
393
- # Test model
394
- summary = test_model(my_model_descr, weight_format="pytorch_state_dict")
395
- summary.display()
396
- summary = test_model(my_model_descr, weight_format="torchscript")
397
- summary.display()
398
-
399
- # Save BioImage.IO package
400
- package_path = save_bioimageio_package(my_model_descr, output_path=Path(path_bioimageio_package))
401
- print("package path:", package_path)
402
-
403
-
404
- if __name__ == "__main__":
405
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/gui/gui.py DELETED
@@ -1,2007 +0,0 @@
1
- """
2
- Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
- """
4
-
5
- import sys, os, pathlib, warnings, datetime, time, copy
6
-
7
- from qtpy import QtGui, QtCore
8
- from superqt import QRangeSlider, QCollapsible
9
- from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, \
10
- QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, \
11
- QLineEdit, QMessageBox, QGroupBox, QMenu, QAction
12
- import pyqtgraph as pg
13
-
14
- import numpy as np
15
- from scipy.stats import mode
16
- import cv2
17
-
18
- from . import guiparts, menus, io
19
- from .. import models, core, dynamics, version, train
20
- from ..utils import download_url_to_file, masks_to_outlines, diameters
21
- from ..io import get_image_files, imsave, imread
22
- from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img
23
- from ..models import normalize_default
24
- from ..plot import disk
25
-
26
- try:
27
- import matplotlib.pyplot as plt
28
- MATPLOTLIB = True
29
- except:
30
- MATPLOTLIB = False
31
-
32
- Horizontal = QtCore.Qt.Orientation.Horizontal
33
-
34
-
35
- class Slider(QRangeSlider):
36
-
37
- def __init__(self, parent, name, color):
38
- super().__init__(Horizontal)
39
- self.setEnabled(False)
40
- self.valueChanged.connect(lambda: self.levelChanged(parent))
41
- self.name = name
42
-
43
- self.setStyleSheet(""" QSlider{
44
- background-color: transparent;
45
- }
46
- """)
47
- self.show()
48
-
49
- def levelChanged(self, parent):
50
- parent.level_change(self.name)
51
-
52
-
53
- class QHLine(QFrame):
54
-
55
- def __init__(self):
56
- super(QHLine, self).__init__()
57
- self.setFrameShape(QFrame.HLine)
58
- self.setLineWidth(8)
59
-
60
-
61
- def make_bwr():
62
- # make a bwr colormap
63
- b = np.append(255 * np.ones(128), np.linspace(0, 255, 128)[::-1])[:, np.newaxis]
64
- r = np.append(np.linspace(0, 255, 128), 255 * np.ones(128))[:, np.newaxis]
65
- g = np.append(np.linspace(0, 255, 128),
66
- np.linspace(0, 255, 128)[::-1])[:, np.newaxis]
67
- color = np.concatenate((r, g, b), axis=-1).astype(np.uint8)
68
- bwr = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
69
- return bwr
70
-
71
-
72
- def make_spectral():
73
- # make spectral colormap
74
- r = np.array([
75
- 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80,
76
- 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 128, 128, 128, 128, 128,
77
- 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 120, 112, 104, 96, 88,
78
- 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
79
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 7, 11, 15, 19, 23,
80
- 27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79, 83, 87, 91, 95, 99, 103,
81
- 107, 111, 115, 119, 123, 127, 131, 135, 139, 143, 147, 151, 155, 159, 163, 167,
82
- 171, 175, 179, 183, 187, 191, 195, 199, 203, 207, 211, 215, 219, 223, 227, 231,
83
- 235, 239, 243, 247, 251, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
84
- 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
85
- 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
86
- 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
87
- 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
88
- 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
89
- 255, 255, 255, 255, 255
90
- ])
91
- g = np.array([
92
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 5, 4, 4, 3, 3,
93
- 2, 2, 1, 1, 0, 0, 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111,
94
- 119, 127, 135, 143, 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239,
95
- 247, 255, 247, 239, 231, 223, 215, 207, 199, 191, 183, 175, 167, 159, 151, 143,
96
- 135, 128, 129, 131, 132, 134, 135, 137, 139, 140, 142, 143, 145, 147, 148, 150,
97
- 151, 153, 154, 156, 158, 159, 161, 162, 164, 166, 167, 169, 170, 172, 174, 175,
98
- 177, 178, 180, 181, 183, 185, 186, 188, 189, 191, 193, 194, 196, 197, 199, 201,
99
- 202, 204, 205, 207, 208, 210, 212, 213, 215, 216, 218, 220, 221, 223, 224, 226,
100
- 228, 229, 231, 232, 234, 235, 237, 239, 240, 242, 243, 245, 247, 248, 250, 251,
101
- 253, 255, 251, 247, 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199,
102
- 195, 191, 187, 183, 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135,
103
- 131, 127, 123, 119, 115, 111, 107, 103, 99, 95, 91, 87, 83, 79, 75, 71, 67, 63,
104
- 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 0, 8, 16, 24, 32, 41,
105
- 49, 57, 65, 74, 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180,
106
- 189, 197, 205, 213, 222, 230, 238, 246, 254
107
- ])
108
- b = np.array([
109
- 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 135, 143,
110
- 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, 247, 255, 255, 255,
111
- 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
112
- 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 251, 247,
113
- 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, 195, 191, 187, 183,
114
- 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, 131, 128, 126, 124,
115
- 122, 120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100, 98, 96, 94, 92, 90,
116
- 88, 86, 84, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50,
117
- 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10,
118
- 8, 6, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
119
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
120
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 16, 24, 32, 41, 49, 57, 65, 74,
121
- 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, 189, 197, 205,
122
- 213, 222, 230, 238, 246, 254
123
- ])
124
- color = (np.vstack((r, g, b)).T).astype(np.uint8)
125
- spectral = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
126
- return spectral
127
-
128
-
129
- def make_cmap(cm=0):
130
- # make a single channel colormap
131
- r = np.arange(0, 256)
132
- color = np.zeros((256, 3))
133
- color[:, cm] = r
134
- color = color.astype(np.uint8)
135
- cmap = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
136
- return cmap
137
-
138
-
139
- def run(image=None):
140
- from ..io import logger_setup
141
- logger, log_file = logger_setup()
142
- # Always start by initializing Qt (only once per application)
143
- warnings.filterwarnings("ignore")
144
- app = QApplication(sys.argv)
145
- icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
146
- guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png")
147
- if not icon_path.is_file():
148
- cp_dir = pathlib.Path.home().joinpath(".cellpose")
149
- cp_dir.mkdir(exist_ok=True)
150
- print("downloading logo")
151
- download_url_to_file(
152
- "https://www.cellpose.org/static/images/cellpose_transparent.png",
153
- icon_path, progress=True)
154
- if not guip_path.is_file():
155
- print("downloading help window image")
156
- download_url_to_file("https://www.cellpose.org/static/images/cellposeSAM_gui.png",
157
- guip_path, progress=True)
158
- icon_path = str(icon_path.resolve())
159
- app_icon = QtGui.QIcon()
160
- app_icon.addFile(icon_path, QtCore.QSize(16, 16))
161
- app_icon.addFile(icon_path, QtCore.QSize(24, 24))
162
- app_icon.addFile(icon_path, QtCore.QSize(32, 32))
163
- app_icon.addFile(icon_path, QtCore.QSize(48, 48))
164
- app_icon.addFile(icon_path, QtCore.QSize(64, 64))
165
- app_icon.addFile(icon_path, QtCore.QSize(256, 256))
166
- app.setWindowIcon(app_icon)
167
- app.setStyle("Fusion")
168
- app.setPalette(guiparts.DarkPalette())
169
- MainW(image=image, logger=logger)
170
- ret = app.exec_()
171
- sys.exit(ret)
172
-
173
-
174
- class MainW(QMainWindow):
175
-
176
- def __init__(self, image=None, logger=None):
177
- super(MainW, self).__init__()
178
-
179
- self.logger = logger
180
- pg.setConfigOptions(imageAxisOrder="row-major")
181
- self.setGeometry(50, 50, 1200, 1000)
182
- self.setWindowTitle(f"cellpose v{version}")
183
- self.cp_path = os.path.dirname(os.path.realpath(__file__))
184
- app_icon = QtGui.QIcon()
185
- icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
186
- icon_path = str(icon_path.resolve())
187
- app_icon.addFile(icon_path, QtCore.QSize(16, 16))
188
- app_icon.addFile(icon_path, QtCore.QSize(24, 24))
189
- app_icon.addFile(icon_path, QtCore.QSize(32, 32))
190
- app_icon.addFile(icon_path, QtCore.QSize(48, 48))
191
- app_icon.addFile(icon_path, QtCore.QSize(64, 64))
192
- app_icon.addFile(icon_path, QtCore.QSize(256, 256))
193
- self.setWindowIcon(app_icon)
194
- # rgb(150,255,150)
195
- self.setStyleSheet(guiparts.stylesheet())
196
-
197
- menus.mainmenu(self)
198
- menus.editmenu(self)
199
- menus.modelmenu(self)
200
- menus.helpmenu(self)
201
-
202
- self.stylePressed = """QPushButton {Text-align: center;
203
- background-color: rgb(150,50,150);
204
- border-color: white;
205
- color:white;}
206
- QToolTip {
207
- background-color: black;
208
- color: white;
209
- border: black solid 1px
210
- }"""
211
- self.styleUnpressed = """QPushButton {Text-align: center;
212
- background-color: rgb(50,50,50);
213
- border-color: white;
214
- color:white;}
215
- QToolTip {
216
- background-color: black;
217
- color: white;
218
- border: black solid 1px
219
- }"""
220
- self.loaded = False
221
-
222
- # ---- MAIN WIDGET LAYOUT ---- #
223
- self.cwidget = QWidget(self)
224
- self.lmain = QGridLayout()
225
- self.cwidget.setLayout(self.lmain)
226
- self.setCentralWidget(self.cwidget)
227
- self.lmain.setVerticalSpacing(0)
228
- self.lmain.setContentsMargins(0, 0, 0, 10)
229
-
230
- self.imask = 0
231
- self.scrollarea = QScrollArea()
232
- self.scrollarea.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
233
- self.scrollarea.setStyleSheet("""QScrollArea { border: none }""")
234
- self.scrollarea.setWidgetResizable(True)
235
- self.swidget = QWidget(self)
236
- self.scrollarea.setWidget(self.swidget)
237
- self.l0 = QGridLayout()
238
- self.swidget.setLayout(self.l0)
239
- b = self.make_buttons()
240
- self.lmain.addWidget(self.scrollarea, 0, 0, 39, 9)
241
-
242
- # ---- drawing area ---- #
243
- self.win = pg.GraphicsLayoutWidget()
244
-
245
- self.lmain.addWidget(self.win, 0, 9, 40, 30)
246
-
247
- self.win.scene().sigMouseClicked.connect(self.plot_clicked)
248
- self.win.scene().sigMouseMoved.connect(self.mouse_moved)
249
- self.make_viewbox()
250
- self.lmain.setColumnStretch(10, 1)
251
- bwrmap = make_bwr()
252
- self.bwr = bwrmap.getLookupTable(start=0.0, stop=255.0, alpha=False)
253
- self.cmap = []
254
- # spectral colormap
255
- self.cmap.append(make_spectral().getLookupTable(start=0.0, stop=255.0,
256
- alpha=False))
257
- # single channel colormaps
258
- for i in range(3):
259
- self.cmap.append(
260
- make_cmap(i).getLookupTable(start=0.0, stop=255.0, alpha=False))
261
-
262
- if MATPLOTLIB:
263
- self.colormap = (plt.get_cmap("gist_ncar")(np.linspace(0.0, .9, 1000000)) *
264
- 255).astype(np.uint8)
265
- np.random.seed(42) # make colors stable
266
- self.colormap = self.colormap[np.random.permutation(1000000)]
267
- else:
268
- np.random.seed(42) # make colors stable
269
- self.colormap = ((np.random.rand(1000000, 3) * 0.8 + 0.1) * 255).astype(
270
- np.uint8)
271
- self.NZ = 1
272
- self.restore = None
273
- self.ratio = 1.
274
- self.reset()
275
-
276
- # This needs to go after .reset() is called to get state fully set up:
277
- self.autobtn.checkStateChanged.connect(self.compute_saturation_if_checked)
278
-
279
- self.load_3D = False
280
-
281
- # if called with image, load it
282
- if image is not None:
283
- self.filename = image
284
- io._load_image(self, self.filename)
285
-
286
- # training settings
287
- d = datetime.datetime.now()
288
- self.training_params = {
289
- "model_index": 0,
290
- "learning_rate": 1e-5,
291
- "weight_decay": 0.1,
292
- "n_epochs": 100,
293
- "model_name": "cpsam" + d.strftime("_%Y%m%d_%H%M%S"),
294
- }
295
-
296
- self.stitch_threshold = 0.
297
- self.flow3D_smooth = 0.
298
- self.anisotropy = 1.
299
- self.min_size = 15
300
-
301
- self.setAcceptDrops(True)
302
- self.win.show()
303
- self.show()
304
-
305
- def help_window(self):
306
- HW = guiparts.HelpWindow(self)
307
- HW.show()
308
-
309
- def train_help_window(self):
310
- THW = guiparts.TrainHelpWindow(self)
311
- THW.show()
312
-
313
- def gui_window(self):
314
- EG = guiparts.ExampleGUI(self)
315
- EG.show()
316
-
317
- def make_buttons(self):
318
- self.boldfont = QtGui.QFont("Arial", 11, QtGui.QFont.Bold)
319
- self.boldmedfont = QtGui.QFont("Arial", 9, QtGui.QFont.Bold)
320
- self.medfont = QtGui.QFont("Arial", 9)
321
- self.smallfont = QtGui.QFont("Arial", 8)
322
-
323
- b = 0
324
- self.satBox = QGroupBox("Views")
325
- self.satBox.setFont(self.boldfont)
326
- self.satBoxG = QGridLayout()
327
- self.satBox.setLayout(self.satBoxG)
328
- self.l0.addWidget(self.satBox, b, 0, 1, 9)
329
-
330
- widget_row = 0
331
- self.view = 0 # 0=image, 1=flowsXY, 2=flowsZ, 3=cellprob
332
- self.color = 0 # 0=RGB, 1=gray, 2=R, 3=G, 4=B
333
- self.RGBDropDown = QComboBox()
334
- self.RGBDropDown.addItems(
335
- ["RGB", "red=R", "green=G", "blue=B", "gray", "spectral"])
336
- self.RGBDropDown.setFont(self.medfont)
337
- self.RGBDropDown.currentIndexChanged.connect(self.color_choose)
338
- self.satBoxG.addWidget(self.RGBDropDown, widget_row, 0, 1, 3)
339
-
340
- label = QLabel("<p>[&uarr; / &darr; or W/S]</p>")
341
- label.setFont(self.smallfont)
342
- self.satBoxG.addWidget(label, widget_row, 3, 1, 3)
343
- label = QLabel("[R / G / B \n toggles color ]")
344
- label.setFont(self.smallfont)
345
- self.satBoxG.addWidget(label, widget_row, 6, 1, 3)
346
-
347
- widget_row += 1
348
- self.ViewDropDown = QComboBox()
349
- self.ViewDropDown.addItems(["image", "gradXY", "cellprob", "restored"])
350
- self.ViewDropDown.setFont(self.medfont)
351
- self.ViewDropDown.model().item(3).setEnabled(False)
352
- self.ViewDropDown.currentIndexChanged.connect(self.update_plot)
353
- self.satBoxG.addWidget(self.ViewDropDown, widget_row, 0, 2, 3)
354
-
355
- label = QLabel("[pageup / pagedown]")
356
- label.setFont(self.smallfont)
357
- self.satBoxG.addWidget(label, widget_row, 3, 1, 5)
358
-
359
- widget_row += 2
360
- label = QLabel("")
361
- label.setToolTip(
362
- "NOTE: manually changing the saturation bars does not affect normalization in segmentation"
363
- )
364
- self.satBoxG.addWidget(label, widget_row, 0, 1, 5)
365
-
366
- self.autobtn = QCheckBox("auto-adjust saturation")
367
- self.autobtn.setToolTip("sets scale-bars as normalized for segmentation")
368
- self.autobtn.setFont(self.medfont)
369
- self.autobtn.setChecked(True)
370
- self.satBoxG.addWidget(self.autobtn, widget_row, 1, 1, 8)
371
-
372
- widget_row += 1
373
- self.sliders = []
374
- colors = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [100, 100, 100]]
375
- colornames = ["red", "Chartreuse", "DodgerBlue"]
376
- names = ["red", "green", "blue"]
377
- for r in range(3):
378
- widget_row += 1
379
- if r == 0:
380
- label = QLabel('<font color="gray">gray/</font><br>red')
381
- else:
382
- label = QLabel(names[r] + ":")
383
- label.setStyleSheet(f"color: {colornames[r]}")
384
- label.setFont(self.boldmedfont)
385
- self.satBoxG.addWidget(label, widget_row, 0, 1, 2)
386
- self.sliders.append(Slider(self, names[r], colors[r]))
387
- self.sliders[-1].setMinimum(-.1)
388
- self.sliders[-1].setMaximum(255.1)
389
- self.sliders[-1].setValue([0, 255])
390
- self.sliders[-1].setToolTip(
391
- "NOTE: manually changing the saturation bars does not affect normalization in segmentation"
392
- )
393
- self.satBoxG.addWidget(self.sliders[-1], widget_row, 2, 1, 7)
394
-
395
- b += 1
396
- self.drawBox = QGroupBox("Drawing")
397
- self.drawBox.setFont(self.boldfont)
398
- self.drawBoxG = QGridLayout()
399
- self.drawBox.setLayout(self.drawBoxG)
400
- self.l0.addWidget(self.drawBox, b, 0, 1, 9)
401
- self.autosave = True
402
-
403
- widget_row = 0
404
- self.brush_size = 3
405
- self.BrushChoose = QComboBox()
406
- self.BrushChoose.addItems(["1", "3", "5", "7", "9"])
407
- self.BrushChoose.currentIndexChanged.connect(self.brush_choose)
408
- self.BrushChoose.setFixedWidth(40)
409
- self.BrushChoose.setFont(self.medfont)
410
- self.drawBoxG.addWidget(self.BrushChoose, widget_row, 3, 1, 2)
411
- label = QLabel("brush size:")
412
- label.setFont(self.medfont)
413
- self.drawBoxG.addWidget(label, widget_row, 0, 1, 3)
414
-
415
- widget_row += 1
416
- # turn off masks
417
- self.layer_off = False
418
- self.masksOn = True
419
- self.MCheckBox = QCheckBox("MASKS ON [X]")
420
- self.MCheckBox.setFont(self.medfont)
421
- self.MCheckBox.setChecked(True)
422
- self.MCheckBox.toggled.connect(self.toggle_masks)
423
- self.drawBoxG.addWidget(self.MCheckBox, widget_row, 0, 1, 5)
424
-
425
- widget_row += 1
426
- # turn off outlines
427
- self.outlinesOn = False # turn off by default
428
- self.OCheckBox = QCheckBox("outlines on [Z]")
429
- self.OCheckBox.setFont(self.medfont)
430
- self.drawBoxG.addWidget(self.OCheckBox, widget_row, 0, 1, 5)
431
- self.OCheckBox.setChecked(False)
432
- self.OCheckBox.toggled.connect(self.toggle_masks)
433
-
434
- widget_row += 1
435
- self.SCheckBox = QCheckBox("single stroke")
436
- self.SCheckBox.setFont(self.medfont)
437
- self.SCheckBox.setChecked(True)
438
- self.SCheckBox.toggled.connect(self.autosave_on)
439
- self.SCheckBox.setEnabled(True)
440
- self.drawBoxG.addWidget(self.SCheckBox, widget_row, 0, 1, 5)
441
-
442
- # buttons for deleting multiple cells
443
- self.deleteBox = QGroupBox("delete multiple ROIs")
444
- self.deleteBox.setStyleSheet("color: rgb(200, 200, 200)")
445
- self.deleteBox.setFont(self.medfont)
446
- self.deleteBoxG = QGridLayout()
447
- self.deleteBox.setLayout(self.deleteBoxG)
448
- self.drawBoxG.addWidget(self.deleteBox, 0, 5, 4, 4)
449
- self.MakeDeletionRegionButton = QPushButton("region-select")
450
- self.MakeDeletionRegionButton.clicked.connect(self.remove_region_cells)
451
- self.deleteBoxG.addWidget(self.MakeDeletionRegionButton, 0, 0, 1, 4)
452
- self.MakeDeletionRegionButton.setFont(self.smallfont)
453
- self.MakeDeletionRegionButton.setFixedWidth(70)
454
- self.DeleteMultipleROIButton = QPushButton("click-select")
455
- self.DeleteMultipleROIButton.clicked.connect(self.delete_multiple_cells)
456
- self.deleteBoxG.addWidget(self.DeleteMultipleROIButton, 1, 0, 1, 4)
457
- self.DeleteMultipleROIButton.setFont(self.smallfont)
458
- self.DeleteMultipleROIButton.setFixedWidth(70)
459
- self.DoneDeleteMultipleROIButton = QPushButton("done")
460
- self.DoneDeleteMultipleROIButton.clicked.connect(
461
- self.done_remove_multiple_cells)
462
- self.deleteBoxG.addWidget(self.DoneDeleteMultipleROIButton, 2, 0, 1, 2)
463
- self.DoneDeleteMultipleROIButton.setFont(self.smallfont)
464
- self.DoneDeleteMultipleROIButton.setFixedWidth(35)
465
- self.CancelDeleteMultipleROIButton = QPushButton("cancel")
466
- self.CancelDeleteMultipleROIButton.clicked.connect(self.cancel_remove_multiple)
467
- self.deleteBoxG.addWidget(self.CancelDeleteMultipleROIButton, 2, 2, 1, 2)
468
- self.CancelDeleteMultipleROIButton.setFont(self.smallfont)
469
- self.CancelDeleteMultipleROIButton.setFixedWidth(35)
470
-
471
- b += 1
472
- widget_row = 0
473
- self.segBox = QGroupBox("Segmentation")
474
- self.segBoxG = QGridLayout()
475
- self.segBox.setLayout(self.segBoxG)
476
- self.l0.addWidget(self.segBox, b, 0, 1, 9)
477
- self.segBox.setFont(self.boldfont)
478
-
479
- widget_row += 1
480
-
481
- # use GPU
482
- self.useGPU = QCheckBox("use GPU")
483
- self.useGPU.setToolTip(
484
- "if you have specially installed the <i>cuda</i> version of torch, then you can activate this"
485
- )
486
- self.useGPU.setFont(self.medfont)
487
- self.check_gpu()
488
- self.segBoxG.addWidget(self.useGPU, widget_row, 0, 1, 3)
489
-
490
- # compute segmentation with general models
491
- self.net_text = ["run CPSAM"]
492
- nett = ["cellpose super-generalist model"]
493
-
494
- self.StyleButtons = []
495
- jj = 4
496
- for j in range(len(self.net_text)):
497
- self.StyleButtons.append(
498
- guiparts.ModelButton(self, self.net_text[j], self.net_text[j]))
499
- w = 5
500
- self.segBoxG.addWidget(self.StyleButtons[-1], widget_row, jj, 1, w)
501
- jj += w
502
- self.StyleButtons[-1].setToolTip(nett[j])
503
-
504
- widget_row += 1
505
- self.ncells = guiparts.ObservableVariable(0)
506
- self.roi_count = QLabel()
507
- self.roi_count.setFont(self.boldfont)
508
- self.roi_count.setAlignment(QtCore.Qt.AlignLeft)
509
- self.ncells.valueChanged.connect(
510
- lambda n: self.roi_count.setText(f'{str(n)} ROIs')
511
- )
512
-
513
- self.segBoxG.addWidget(self.roi_count, widget_row, 0, 1, 4)
514
-
515
- self.progress = QProgressBar(self)
516
- self.segBoxG.addWidget(self.progress, widget_row, 4, 1, 5)
517
-
518
- widget_row += 1
519
-
520
- ############################### Segmentation settings ###############################
521
- self.additional_seg_settings_qcollapsible = QCollapsible("additional settings")
522
- self.additional_seg_settings_qcollapsible.setFont(self.medfont)
523
- self.additional_seg_settings_qcollapsible._toggle_btn.setFont(self.medfont)
524
- self.segmentation_settings = guiparts.SegmentationSettings(self.medfont)
525
- self.additional_seg_settings_qcollapsible.setContent(self.segmentation_settings)
526
- self.segBoxG.addWidget(self.additional_seg_settings_qcollapsible, widget_row, 0, 1, 9)
527
-
528
- # connect edits to image processing steps:
529
- self.segmentation_settings.diameter_box.editingFinished.connect(self.update_scale)
530
- self.segmentation_settings.flow_threshold_box.returnPressed.connect(self.compute_cprob)
531
- self.segmentation_settings.cellprob_threshold_box.returnPressed.connect(self.compute_cprob)
532
- self.segmentation_settings.niter_box.returnPressed.connect(self.compute_cprob)
533
-
534
- # Needed to do this for the drop down to not be open on startup
535
- self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(True)
536
- self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(False)
537
-
538
- b += 1
539
- self.modelBox = QGroupBox("user-trained models")
540
- self.modelBoxG = QGridLayout()
541
- self.modelBox.setLayout(self.modelBoxG)
542
- self.l0.addWidget(self.modelBox, b, 0, 1, 9)
543
- self.modelBox.setFont(self.boldfont)
544
- # choose models
545
- self.ModelChooseC = QComboBox()
546
- self.ModelChooseC.setFont(self.medfont)
547
- current_index = 0
548
- self.ModelChooseC.addItems(["custom models"])
549
- if len(self.model_strings) > 0:
550
- self.ModelChooseC.addItems(self.model_strings)
551
- self.ModelChooseC.setFixedWidth(175)
552
- self.ModelChooseC.setCurrentIndex(current_index)
553
- tipstr = 'add or train your own models in the "Models" file menu and choose model here'
554
- self.ModelChooseC.setToolTip(tipstr)
555
- self.ModelChooseC.activated.connect(lambda: self.model_choose(custom=True))
556
- self.modelBoxG.addWidget(self.ModelChooseC, widget_row, 0, 1, 8)
557
-
558
- # compute segmentation w/ custom model
559
- self.ModelButtonC = QPushButton(u"run")
560
- self.ModelButtonC.setFont(self.medfont)
561
- self.ModelButtonC.setFixedWidth(35)
562
- self.ModelButtonC.clicked.connect(
563
- lambda: self.compute_segmentation(custom=True))
564
- self.modelBoxG.addWidget(self.ModelButtonC, widget_row, 8, 1, 1)
565
- self.ModelButtonC.setEnabled(False)
566
-
567
-
568
- b += 1
569
- self.filterBox = QGroupBox("Image filtering")
570
- self.filterBox.setFont(self.boldfont)
571
- self.filterBox_grid_layout = QGridLayout()
572
- self.filterBox.setLayout(self.filterBox_grid_layout)
573
- self.l0.addWidget(self.filterBox, b, 0, 1, 9)
574
-
575
- widget_row = 0
576
-
577
- # Filtering
578
- self.FilterButtons = []
579
- nett = [
580
- "clear restore/filter",
581
- "filter image (settings below)",
582
- ]
583
- self.filter_text = ["none",
584
- "filter",
585
- ]
586
- self.restore = None
587
- self.ratio = 1.
588
- jj = 0
589
- w = 3
590
- for j in range(len(self.filter_text)):
591
- self.FilterButtons.append(
592
- guiparts.FilterButton(self, self.filter_text[j]))
593
- self.filterBox_grid_layout.addWidget(self.FilterButtons[-1], widget_row, jj, 1, w)
594
- self.FilterButtons[-1].setFixedWidth(75)
595
- self.FilterButtons[-1].setToolTip(nett[j])
596
- self.FilterButtons[-1].setFont(self.medfont)
597
- widget_row += 1 if j%2==1 else 0
598
- jj = 0 if j%2==1 else jj + w
599
-
600
- self.save_norm = QCheckBox("save restored/filtered image")
601
- self.save_norm.setFont(self.medfont)
602
- self.save_norm.setToolTip("save restored/filtered image in _seg.npy file")
603
- self.save_norm.setChecked(True)
604
-
605
- widget_row += 2
606
-
607
- self.filtBox = QCollapsible("custom filter settings")
608
- self.filtBox._toggle_btn.setFont(self.medfont)
609
- self.filtBoxG = QGridLayout()
610
- _content = QWidget()
611
- _content.setLayout(self.filtBoxG)
612
- _content.setMaximumHeight(0)
613
- _content.setMinimumHeight(0)
614
- self.filtBox.setContent(_content)
615
- self.filterBox_grid_layout.addWidget(self.filtBox, widget_row, 0, 1, 9)
616
-
617
- self.filt_vals = [0., 0., 0., 0.]
618
- self.filt_edits = []
619
- labels = [
620
- "sharpen\nradius", "smooth\nradius", "tile_norm\nblocksize",
621
- "tile_norm\nsmooth3D"
622
- ]
623
- tooltips = [
624
- "set size of surround-subtraction filter for sharpening image",
625
- "set size of gaussian filter for smoothing image",
626
- "set size of tiles to use to normalize image",
627
- "set amount of smoothing of normalization values across planes"
628
- ]
629
-
630
- for p in range(4):
631
- label = QLabel(f"{labels[p]}:")
632
- label.setToolTip(tooltips[p])
633
- label.setFont(self.medfont)
634
- self.filtBoxG.addWidget(label, widget_row + p // 2, 4 * (p % 2), 1, 2)
635
- self.filt_edits.append(QLineEdit())
636
- self.filt_edits[p].setText(str(self.filt_vals[p]))
637
- self.filt_edits[p].setFixedWidth(40)
638
- self.filt_edits[p].setFont(self.medfont)
639
- self.filtBoxG.addWidget(self.filt_edits[p], widget_row + p // 2, 4 * (p % 2) + 2, 1,
640
- 2)
641
- self.filt_edits[p].setToolTip(tooltips[p])
642
-
643
- widget_row += 3
644
- self.norm3D_cb = QCheckBox("norm3D")
645
- self.norm3D_cb.setFont(self.medfont)
646
- self.norm3D_cb.setChecked(True)
647
- self.norm3D_cb.setToolTip("run same normalization across planes")
648
- self.filtBoxG.addWidget(self.norm3D_cb, widget_row, 0, 1, 3)
649
-
650
-
651
- return b
652
-
653
- def level_change(self, r):
654
- r = ["red", "green", "blue"].index(r)
655
- if self.loaded:
656
- sval = self.sliders[r].value()
657
- self.saturation[r][self.currentZ] = sval
658
- if not self.autobtn.isChecked():
659
- for r in range(3):
660
- for i in range(len(self.saturation[r])):
661
- self.saturation[r][i] = self.saturation[r][self.currentZ]
662
- self.update_plot()
663
-
664
- def keyPressEvent(self, event):
665
- if self.loaded:
666
- if not (event.modifiers() &
667
- (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier |
668
- QtCore.Qt.AltModifier) or self.in_stroke):
669
- updated = False
670
- if len(self.current_point_set) > 0:
671
- if event.key() == QtCore.Qt.Key_Return:
672
- self.add_set()
673
- else:
674
- nviews = self.ViewDropDown.count() - 1
675
- nviews += int(
676
- self.ViewDropDown.model().item(self.ViewDropDown.count() -
677
- 1).isEnabled())
678
- if event.key() == QtCore.Qt.Key_X:
679
- self.MCheckBox.toggle()
680
- if event.key() == QtCore.Qt.Key_Z:
681
- self.OCheckBox.toggle()
682
- if event.key() == QtCore.Qt.Key_Left or event.key(
683
- ) == QtCore.Qt.Key_A:
684
- self.get_prev_image()
685
- elif event.key() == QtCore.Qt.Key_Right or event.key(
686
- ) == QtCore.Qt.Key_D:
687
- self.get_next_image()
688
- elif event.key() == QtCore.Qt.Key_PageDown:
689
- self.view = (self.view + 1) % (nviews)
690
- self.ViewDropDown.setCurrentIndex(self.view)
691
- elif event.key() == QtCore.Qt.Key_PageUp:
692
- self.view = (self.view - 1) % (nviews)
693
- self.ViewDropDown.setCurrentIndex(self.view)
694
-
695
- # can change background or stroke size if cell not finished
696
- if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W:
697
- self.color = (self.color - 1) % (6)
698
- self.RGBDropDown.setCurrentIndex(self.color)
699
- elif event.key() == QtCore.Qt.Key_Down or event.key(
700
- ) == QtCore.Qt.Key_S:
701
- self.color = (self.color + 1) % (6)
702
- self.RGBDropDown.setCurrentIndex(self.color)
703
- elif event.key() == QtCore.Qt.Key_R:
704
- if self.color != 1:
705
- self.color = 1
706
- else:
707
- self.color = 0
708
- self.RGBDropDown.setCurrentIndex(self.color)
709
- elif event.key() == QtCore.Qt.Key_G:
710
- if self.color != 2:
711
- self.color = 2
712
- else:
713
- self.color = 0
714
- self.RGBDropDown.setCurrentIndex(self.color)
715
- elif event.key() == QtCore.Qt.Key_B:
716
- if self.color != 3:
717
- self.color = 3
718
- else:
719
- self.color = 0
720
- self.RGBDropDown.setCurrentIndex(self.color)
721
- elif (event.key() == QtCore.Qt.Key_Comma or
722
- event.key() == QtCore.Qt.Key_Period):
723
- count = self.BrushChoose.count()
724
- gci = self.BrushChoose.currentIndex()
725
- if event.key() == QtCore.Qt.Key_Comma:
726
- gci = max(0, gci - 1)
727
- else:
728
- gci = min(count - 1, gci + 1)
729
- self.BrushChoose.setCurrentIndex(gci)
730
- self.brush_choose()
731
- if not updated:
732
- self.update_plot()
733
- if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
734
- self.p0.keyPressEvent(event)
735
-
736
- def autosave_on(self):
737
- if self.SCheckBox.isChecked():
738
- self.autosave = True
739
- else:
740
- self.autosave = False
741
-
742
- def check_gpu(self, torch=True):
743
- # also decide whether or not to use torch
744
- self.useGPU.setChecked(False)
745
- self.useGPU.setEnabled(False)
746
- if core.use_gpu(use_torch=True):
747
- self.useGPU.setEnabled(True)
748
- self.useGPU.setChecked(True)
749
- else:
750
- self.useGPU.setStyleSheet("color: rgb(80,80,80);")
751
-
752
-
753
- def model_choose(self, custom=False):
754
- index = self.ModelChooseC.currentIndex(
755
- ) if custom else self.ModelChooseB.currentIndex()
756
- if index > 0:
757
- if custom:
758
- model_name = self.ModelChooseC.currentText()
759
- else:
760
- model_name = self.net_names[index - 1]
761
- print(f"GUI_INFO: selected model {model_name}, loading now")
762
- self.initialize_model(model_name=model_name, custom=custom)
763
-
764
- def toggle_scale(self):
765
- if self.scale_on:
766
- self.p0.removeItem(self.scale)
767
- self.scale_on = False
768
- else:
769
- self.p0.addItem(self.scale)
770
- self.scale_on = True
771
-
772
- def enable_buttons(self):
773
- if len(self.model_strings) > 0:
774
- self.ModelButtonC.setEnabled(True)
775
- for i in range(len(self.StyleButtons)):
776
- self.StyleButtons[i].setEnabled(True)
777
-
778
- for i in range(len(self.FilterButtons)):
779
- self.FilterButtons[i].setEnabled(True)
780
- if self.load_3D:
781
- self.FilterButtons[-2].setEnabled(False)
782
-
783
- self.newmodel.setEnabled(True)
784
- self.loadMasks.setEnabled(True)
785
-
786
- for n in range(self.nchan):
787
- self.sliders[n].setEnabled(True)
788
- for n in range(self.nchan, 3):
789
- self.sliders[n].setEnabled(True)
790
-
791
- self.toggle_mask_ops()
792
-
793
- self.update_plot()
794
- self.setWindowTitle(self.filename)
795
-
796
- def disable_buttons_removeROIs(self):
797
- if len(self.model_strings) > 0:
798
- self.ModelButtonC.setEnabled(False)
799
- for i in range(len(self.StyleButtons)):
800
- self.StyleButtons[i].setEnabled(False)
801
- self.newmodel.setEnabled(False)
802
- self.loadMasks.setEnabled(False)
803
- self.saveSet.setEnabled(False)
804
- self.savePNG.setEnabled(False)
805
- self.saveFlows.setEnabled(False)
806
- self.saveOutlines.setEnabled(False)
807
- self.saveROIs.setEnabled(False)
808
-
809
- self.MakeDeletionRegionButton.setEnabled(False)
810
- self.DeleteMultipleROIButton.setEnabled(False)
811
- self.DoneDeleteMultipleROIButton.setEnabled(True)
812
- self.CancelDeleteMultipleROIButton.setEnabled(True)
813
-
814
- def toggle_mask_ops(self):
815
- self.update_layer()
816
- self.toggle_saving()
817
- self.toggle_removals()
818
-
819
- def toggle_saving(self):
820
- if self.ncells > 0:
821
- self.saveSet.setEnabled(True)
822
- self.savePNG.setEnabled(True)
823
- self.saveFlows.setEnabled(True)
824
- self.saveOutlines.setEnabled(True)
825
- self.saveROIs.setEnabled(True)
826
- else:
827
- self.saveSet.setEnabled(False)
828
- self.savePNG.setEnabled(False)
829
- self.saveFlows.setEnabled(False)
830
- self.saveOutlines.setEnabled(False)
831
- self.saveROIs.setEnabled(False)
832
-
833
- def toggle_removals(self):
834
- if self.ncells > 0:
835
- self.ClearButton.setEnabled(True)
836
- self.remcell.setEnabled(True)
837
- self.undo.setEnabled(True)
838
- self.MakeDeletionRegionButton.setEnabled(True)
839
- self.DeleteMultipleROIButton.setEnabled(True)
840
- self.DoneDeleteMultipleROIButton.setEnabled(False)
841
- self.CancelDeleteMultipleROIButton.setEnabled(False)
842
- else:
843
- self.ClearButton.setEnabled(False)
844
- self.remcell.setEnabled(False)
845
- self.undo.setEnabled(False)
846
- self.MakeDeletionRegionButton.setEnabled(False)
847
- self.DeleteMultipleROIButton.setEnabled(False)
848
- self.DoneDeleteMultipleROIButton.setEnabled(False)
849
- self.CancelDeleteMultipleROIButton.setEnabled(False)
850
-
851
- def remove_action(self):
852
- if self.selected > 0:
853
- self.remove_cell(self.selected)
854
-
855
- def undo_action(self):
856
- if (len(self.strokes) > 0 and self.strokes[-1][0][0] == self.currentZ):
857
- self.remove_stroke()
858
- else:
859
- # remove previous cell
860
- if self.ncells > 0:
861
- self.remove_cell(self.ncells.get())
862
-
863
- def undo_remove_action(self):
864
- self.undo_remove_cell()
865
-
866
- def get_files(self):
867
- folder = os.path.dirname(self.filename)
868
- mask_filter = "_masks"
869
- images = get_image_files(folder, mask_filter)
870
- fnames = [os.path.split(images[k])[-1] for k in range(len(images))]
871
- f0 = os.path.split(self.filename)[-1]
872
- idx = np.nonzero(np.array(fnames) == f0)[0][0]
873
- return images, idx
874
-
875
- def get_prev_image(self):
876
- images, idx = self.get_files()
877
- idx = (idx - 1) % len(images)
878
- io._load_image(self, filename=images[idx])
879
-
880
- def get_next_image(self, load_seg=True):
881
- images, idx = self.get_files()
882
- idx = (idx + 1) % len(images)
883
- io._load_image(self, filename=images[idx], load_seg=load_seg)
884
-
885
- def dragEnterEvent(self, event):
886
- if event.mimeData().hasUrls():
887
- event.accept()
888
- else:
889
- event.ignore()
890
-
891
- def dropEvent(self, event):
892
- files = [u.toLocalFile() for u in event.mimeData().urls()]
893
- if os.path.splitext(files[0])[-1] == ".npy":
894
- io._load_seg(self, filename=files[0], load_3D=self.load_3D)
895
- else:
896
- io._load_image(self, filename=files[0], load_seg=True, load_3D=self.load_3D)
897
-
898
- def toggle_masks(self):
899
- if self.MCheckBox.isChecked():
900
- self.masksOn = True
901
- else:
902
- self.masksOn = False
903
- if self.OCheckBox.isChecked():
904
- self.outlinesOn = True
905
- else:
906
- self.outlinesOn = False
907
- if not self.masksOn and not self.outlinesOn:
908
- self.p0.removeItem(self.layer)
909
- self.layer_off = True
910
- else:
911
- if self.layer_off:
912
- self.p0.addItem(self.layer)
913
- self.draw_layer()
914
- self.update_layer()
915
- if self.loaded:
916
- self.update_plot()
917
- self.update_layer()
918
-
919
- def make_viewbox(self):
920
- self.p0 = guiparts.ViewBoxNoRightDrag(parent=self, lockAspect=True,
921
- name="plot1", border=[100, 100,
922
- 100], invertY=True)
923
- self.p0.setCursor(QtCore.Qt.CrossCursor)
924
- self.brush_size = 3
925
- self.win.addItem(self.p0, 0, 0, rowspan=1, colspan=1)
926
- self.p0.setMenuEnabled(False)
927
- self.p0.setMouseEnabled(x=True, y=True)
928
- self.img = pg.ImageItem(viewbox=self.p0, parent=self)
929
- self.img.autoDownsample = False
930
- self.layer = guiparts.ImageDraw(viewbox=self.p0, parent=self)
931
- self.layer.setLevels([0, 255])
932
- self.scale = pg.ImageItem(viewbox=self.p0, parent=self)
933
- self.scale.setLevels([0, 255])
934
- self.p0.scene().contextMenuItem = self.p0
935
- self.Ly, self.Lx = 512, 512
936
- self.p0.addItem(self.img)
937
- self.p0.addItem(self.layer)
938
- self.p0.addItem(self.scale)
939
-
940
- def reset(self):
941
- # ---- start sets of points ---- #
942
- self.selected = 0
943
- self.nchan = 3
944
- self.loaded = False
945
- self.channel = [0, 1]
946
- self.current_point_set = []
947
- self.in_stroke = False
948
- self.strokes = []
949
- self.stroke_appended = True
950
- self.resize = False
951
- self.ncells.reset()
952
- self.zdraw = []
953
- self.removed_cell = []
954
- self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]
955
-
956
- # -- zero out image stack -- #
957
- self.opacity = 128 # how opaque masks should be
958
- self.outcolor = [200, 200, 255, 200]
959
- self.NZ, self.Ly, self.Lx = 1, 256, 256
960
- self.saturation = self.saturation if hasattr(self, 'saturation') else []
961
-
962
- # only adjust the saturation if auto-adjust is on:
963
- if self.autobtn.isChecked():
964
- for r in range(3):
965
- self.saturation.append([[0, 255] for n in range(self.NZ)])
966
- self.sliders[r].setValue([0, 255])
967
- self.sliders[r].setEnabled(False)
968
- self.sliders[r].show()
969
- self.currentZ = 0
970
- self.flows = [[], [], [], [], [[]]]
971
- # masks matrix
972
- # image matrix with a scale disk
973
- self.stack = np.zeros((1, self.Ly, self.Lx, 3))
974
- self.Lyr, self.Lxr = self.Ly, self.Lx
975
- self.Ly0, self.Lx0 = self.Ly, self.Lx
976
- self.radii = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
977
- self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
978
- self.cellpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
979
- self.outpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
980
- self.ismanual = np.zeros(0, "bool")
981
-
982
- # -- set menus to default -- #
983
- self.color = 0
984
- self.RGBDropDown.setCurrentIndex(self.color)
985
- self.view = 0
986
- self.ViewDropDown.setCurrentIndex(0)
987
- self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False)
988
- self.delete_restore()
989
-
990
- self.clear_all()
991
-
992
- self.filename = []
993
- self.loaded = False
994
- self.recompute_masks = False
995
-
996
- self.deleting_multiple = False
997
- self.removing_cells_list = []
998
- self.removing_region = False
999
- self.remove_roi_obj = None
1000
-
1001
- def delete_restore(self):
1002
- """ delete restored imgs but don't reset settings """
1003
- if hasattr(self, "stack_filtered"):
1004
- del self.stack_filtered
1005
- if hasattr(self, "cellpix_orig"):
1006
- self.cellpix = self.cellpix_orig.copy()
1007
- self.outpix = self.outpix_orig.copy()
1008
- del self.outpix_orig, self.outpix_resize
1009
- del self.cellpix_orig, self.cellpix_resize
1010
-
1011
- def clear_restore(self):
1012
- """ delete restored imgs and reset settings """
1013
- print("GUI_INFO: clearing restored image")
1014
- self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False)
1015
- if self.ViewDropDown.currentIndex() == self.ViewDropDown.count() - 1:
1016
- self.ViewDropDown.setCurrentIndex(0)
1017
- self.delete_restore()
1018
- self.restore = None
1019
- self.ratio = 1.
1020
- self.set_normalize_params(self.get_normalize_params())
1021
-
1022
- def brush_choose(self):
1023
- self.brush_size = self.BrushChoose.currentIndex() * 2 + 1
1024
- if self.loaded:
1025
- self.layer.setDrawKernel(kernel_size=self.brush_size)
1026
- self.update_layer()
1027
-
1028
- def clear_all(self):
1029
- self.prev_selected = 0
1030
- self.selected = 0
1031
- if self.restore and "upsample" in self.restore:
1032
- self.layerz = 0 * np.ones((self.Lyr, self.Lxr, 4), np.uint8)
1033
- self.cellpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16)
1034
- self.outpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16)
1035
- self.cellpix_resize = self.cellpix.copy()
1036
- self.outpix_resize = self.outpix.copy()
1037
- self.cellpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16)
1038
- self.outpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16)
1039
- else:
1040
- self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
1041
- self.cellpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16)
1042
- self.outpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16)
1043
-
1044
- self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]
1045
- self.ncells.reset()
1046
- self.toggle_removals()
1047
- self.update_scale()
1048
- self.update_layer()
1049
-
1050
- def select_cell(self, idx):
1051
- self.prev_selected = self.selected
1052
- self.selected = idx
1053
- if self.selected > 0:
1054
- z = self.currentZ
1055
- self.layerz[self.cellpix[z] == idx] = np.array(
1056
- [255, 255, 255, self.opacity])
1057
- self.update_layer()
1058
-
1059
- def select_cell_multi(self, idx):
1060
- if idx > 0:
1061
- z = self.currentZ
1062
- self.layerz[self.cellpix[z] == idx] = np.array(
1063
- [255, 255, 255, self.opacity])
1064
- self.update_layer()
1065
-
1066
- def unselect_cell(self):
1067
- if self.selected > 0:
1068
- idx = self.selected
1069
- if idx < (self.ncells.get() + 1):
1070
- z = self.currentZ
1071
- self.layerz[self.cellpix[z] == idx] = np.append(
1072
- self.cellcolors[idx], self.opacity)
1073
- if self.outlinesOn:
1074
- self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype(
1075
- np.uint8)
1076
- #[0,0,0,self.opacity])
1077
- self.update_layer()
1078
- self.selected = 0
1079
-
1080
- def unselect_cell_multi(self, idx):
1081
- z = self.currentZ
1082
- self.layerz[self.cellpix[z] == idx] = np.append(self.cellcolors[idx],
1083
- self.opacity)
1084
- if self.outlinesOn:
1085
- self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype(
1086
- np.uint8)
1087
- # [0,0,0,self.opacity])
1088
- self.update_layer()
1089
-
1090
- def remove_cell(self, idx):
1091
- if isinstance(idx, (int, np.integer)):
1092
- idx = [idx]
1093
- # because the function remove_single_cell updates the state of the cellpix and outpix arrays
1094
- # by reindexing cells to avoid gaps in the indices, we need to remove the cells in reverse order
1095
- # so that the indices are correct
1096
- idx.sort(reverse=True)
1097
- for i in idx:
1098
- self.remove_single_cell(i)
1099
- self.ncells -= len(idx) # _save_sets uses ncells
1100
- self.update_layer()
1101
-
1102
- if self.ncells == 0:
1103
- self.ClearButton.setEnabled(False)
1104
- if self.NZ == 1:
1105
- io._save_sets_with_check(self)
1106
-
1107
-
1108
- def remove_single_cell(self, idx):
1109
- # remove from manual array
1110
- self.selected = 0
1111
- if self.NZ > 1:
1112
- zextent = ((self.cellpix == idx).sum(axis=(1, 2)) > 0).nonzero()[0]
1113
- else:
1114
- zextent = [0]
1115
- for z in zextent:
1116
- cp = self.cellpix[z] == idx
1117
- op = self.outpix[z] == idx
1118
- # remove from self.cellpix and self.outpix
1119
- self.cellpix[z, cp] = 0
1120
- self.outpix[z, op] = 0
1121
- if z == self.currentZ:
1122
- # remove from mask layer
1123
- self.layerz[cp] = np.array([0, 0, 0, 0])
1124
-
1125
- # reduce other pixels by -1
1126
- self.cellpix[self.cellpix > idx] -= 1
1127
- self.outpix[self.outpix > idx] -= 1
1128
-
1129
- if self.NZ == 1:
1130
- self.removed_cell = [
1131
- self.ismanual[idx - 1], self.cellcolors[idx],
1132
- np.nonzero(cp),
1133
- np.nonzero(op)
1134
- ]
1135
- self.redo.setEnabled(True)
1136
- ar, ac = self.removed_cell[2]
1137
- d = datetime.datetime.now()
1138
- self.track_changes.append(
1139
- [d.strftime("%m/%d/%Y, %H:%M:%S"), "removed mask", [ar, ac]])
1140
- # remove cell from lists
1141
- self.ismanual = np.delete(self.ismanual, idx - 1)
1142
- self.cellcolors = np.delete(self.cellcolors, [idx], axis=0)
1143
- del self.zdraw[idx - 1]
1144
- print("GUI_INFO: removed cell %d" % (idx - 1))
1145
-
1146
- def remove_region_cells(self):
1147
- if self.removing_cells_list:
1148
- for idx in self.removing_cells_list:
1149
- self.unselect_cell_multi(idx)
1150
- self.removing_cells_list.clear()
1151
- self.disable_buttons_removeROIs()
1152
- self.removing_region = True
1153
-
1154
- self.clear_multi_selected_cells()
1155
-
1156
- # make roi region here in center of view, making ROI half the size of the view
1157
- roi_width = self.p0.viewRect().width() / 2
1158
- x_loc = self.p0.viewRect().x() + (roi_width / 2)
1159
- roi_height = self.p0.viewRect().height() / 2
1160
- y_loc = self.p0.viewRect().y() + (roi_height / 2)
1161
-
1162
- pos = [x_loc, y_loc]
1163
- roi = pg.RectROI(pos, [roi_width, roi_height], pen=pg.mkPen("y", width=2),
1164
- removable=True)
1165
- roi.sigRemoveRequested.connect(self.remove_roi)
1166
- roi.sigRegionChangeFinished.connect(self.roi_changed)
1167
- self.p0.addItem(roi)
1168
- self.remove_roi_obj = roi
1169
- self.roi_changed(roi)
1170
-
1171
- def delete_multiple_cells(self):
1172
- self.unselect_cell()
1173
- self.disable_buttons_removeROIs()
1174
- self.DoneDeleteMultipleROIButton.setEnabled(True)
1175
- self.MakeDeletionRegionButton.setEnabled(True)
1176
- self.CancelDeleteMultipleROIButton.setEnabled(True)
1177
- self.deleting_multiple = True
1178
-
1179
- def done_remove_multiple_cells(self):
1180
- self.deleting_multiple = False
1181
- self.removing_region = False
1182
- self.DoneDeleteMultipleROIButton.setEnabled(False)
1183
- self.MakeDeletionRegionButton.setEnabled(False)
1184
- self.CancelDeleteMultipleROIButton.setEnabled(False)
1185
-
1186
- if self.removing_cells_list:
1187
- self.removing_cells_list = list(set(self.removing_cells_list))
1188
- display_remove_list = [i - 1 for i in self.removing_cells_list]
1189
- print(f"GUI_INFO: removing cells: {display_remove_list}")
1190
- self.remove_cell(self.removing_cells_list)
1191
- self.removing_cells_list.clear()
1192
- self.unselect_cell()
1193
- self.enable_buttons()
1194
-
1195
- if self.remove_roi_obj is not None:
1196
- self.remove_roi(self.remove_roi_obj)
1197
-
1198
- def merge_cells(self, idx):
1199
- self.prev_selected = self.selected
1200
- self.selected = idx
1201
- if self.selected != self.prev_selected:
1202
- for z in range(self.NZ):
1203
- ar0, ac0 = np.nonzero(self.cellpix[z] == self.prev_selected)
1204
- ar1, ac1 = np.nonzero(self.cellpix[z] == self.selected)
1205
- touching = np.logical_and((ar0[:, np.newaxis] - ar1) < 3,
1206
- (ac0[:, np.newaxis] - ac1) < 3).sum()
1207
- ar = np.hstack((ar0, ar1))
1208
- ac = np.hstack((ac0, ac1))
1209
- vr0, vc0 = np.nonzero(self.outpix[z] == self.prev_selected)
1210
- vr1, vc1 = np.nonzero(self.outpix[z] == self.selected)
1211
- self.outpix[z, vr0, vc0] = 0
1212
- self.outpix[z, vr1, vc1] = 0
1213
- if touching > 0:
1214
- mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8)
1215
- mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
1216
- contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
1217
- cv2.CHAIN_APPROX_NONE)
1218
- pvc, pvr = contours[-2][0].squeeze().T
1219
- vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
1220
-
1221
- else:
1222
- vr = np.hstack((vr0, vr1))
1223
- vc = np.hstack((vc0, vc1))
1224
- color = self.cellcolors[self.prev_selected]
1225
- self.draw_mask(z, ar, ac, vr, vc, color, idx=self.prev_selected)
1226
- self.remove_cell(self.selected)
1227
- print("GUI_INFO: merged two cells")
1228
- self.update_layer()
1229
- io._save_sets_with_check(self)
1230
- self.undo.setEnabled(False)
1231
- self.redo.setEnabled(False)
1232
-
1233
- def undo_remove_cell(self):
1234
- if len(self.removed_cell) > 0:
1235
- z = 0
1236
- ar, ac = self.removed_cell[2]
1237
- vr, vc = self.removed_cell[3]
1238
- color = self.removed_cell[1]
1239
- self.draw_mask(z, ar, ac, vr, vc, color)
1240
- self.toggle_mask_ops()
1241
- self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], axis=0)
1242
- self.ncells += 1
1243
- self.ismanual = np.append(self.ismanual, self.removed_cell[0])
1244
- self.zdraw.append([])
1245
- print(">>> added back removed cell")
1246
- self.update_layer()
1247
- io._save_sets_with_check(self)
1248
- self.removed_cell = []
1249
- self.redo.setEnabled(False)
1250
-
1251
- def remove_stroke(self, delete_points=True, stroke_ind=-1):
1252
- stroke = np.array(self.strokes[stroke_ind])
1253
- cZ = self.currentZ
1254
- inZ = stroke[0, 0] == cZ
1255
- if inZ:
1256
- outpix = self.outpix[cZ, stroke[:, 1], stroke[:, 2]] > 0
1257
- self.layerz[stroke[~outpix, 1], stroke[~outpix, 2]] = np.array([0, 0, 0, 0])
1258
- cellpix = self.cellpix[cZ, stroke[:, 1], stroke[:, 2]]
1259
- ccol = self.cellcolors.copy()
1260
- if self.selected > 0:
1261
- ccol[self.selected] = np.array([255, 255, 255])
1262
- col2mask = ccol[cellpix]
1263
- if self.masksOn:
1264
- col2mask = np.concatenate(
1265
- (col2mask, self.opacity * (cellpix[:, np.newaxis] > 0)), axis=-1)
1266
- else:
1267
- col2mask = np.concatenate((col2mask, 0 * (cellpix[:, np.newaxis] > 0)),
1268
- axis=-1)
1269
- self.layerz[stroke[:, 1], stroke[:, 2], :] = col2mask
1270
- if self.outlinesOn:
1271
- self.layerz[stroke[outpix, 1], stroke[outpix,
1272
- 2]] = np.array(self.outcolor)
1273
- if delete_points:
1274
- del self.current_point_set[stroke_ind]
1275
- self.update_layer()
1276
-
1277
- del self.strokes[stroke_ind]
1278
-
1279
- def plot_clicked(self, event):
1280
- if event.button()==QtCore.Qt.LeftButton \
1281
- and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\
1282
- and not self.removing_region:
1283
- if event.double():
1284
- try:
1285
- self.p0.setYRange(0, self.Ly + self.pr)
1286
- except:
1287
- self.p0.setYRange(0, self.Ly)
1288
- self.p0.setXRange(0, self.Lx)
1289
-
1290
- def cancel_remove_multiple(self):
1291
- self.clear_multi_selected_cells()
1292
- self.done_remove_multiple_cells()
1293
-
1294
- def clear_multi_selected_cells(self):
1295
- # unselect all previously selected cells:
1296
- for idx in self.removing_cells_list:
1297
- self.unselect_cell_multi(idx)
1298
- self.removing_cells_list.clear()
1299
-
1300
- def add_roi(self, roi):
1301
- self.p0.addItem(roi)
1302
- self.remove_roi_obj = roi
1303
-
1304
- def remove_roi(self, roi):
1305
- self.clear_multi_selected_cells()
1306
- assert roi == self.remove_roi_obj
1307
- self.remove_roi_obj = None
1308
- self.p0.removeItem(roi)
1309
- self.removing_region = False
1310
-
1311
- def roi_changed(self, roi):
1312
- # find the overlapping cells and make them selected
1313
- pos = roi.pos()
1314
- size = roi.size()
1315
- x0 = int(pos.x())
1316
- y0 = int(pos.y())
1317
- x1 = int(pos.x() + size.x())
1318
- y1 = int(pos.y() + size.y())
1319
- if x0 < 0:
1320
- x0 = 0
1321
- if y0 < 0:
1322
- y0 = 0
1323
- if x1 > self.Lx:
1324
- x1 = self.Lx
1325
- if y1 > self.Ly:
1326
- y1 = self.Ly
1327
-
1328
- # find cells in that region
1329
- cell_idxs = np.unique(self.cellpix[self.currentZ, y0:y1, x0:x1])
1330
- cell_idxs = np.trim_zeros(cell_idxs)
1331
- # deselect cells not in region by deselecting all and then selecting the ones in the region
1332
- self.clear_multi_selected_cells()
1333
-
1334
- for idx in cell_idxs:
1335
- self.select_cell_multi(idx)
1336
- self.removing_cells_list.append(idx)
1337
-
1338
- self.update_layer()
1339
-
1340
- def mouse_moved(self, pos):
1341
- items = self.win.scene().items(pos)
1342
-
1343
- def color_choose(self):
1344
- self.color = self.RGBDropDown.currentIndex()
1345
- self.view = 0
1346
- self.ViewDropDown.setCurrentIndex(self.view)
1347
- self.update_plot()
1348
-
1349
- def update_plot(self):
1350
- self.view = self.ViewDropDown.currentIndex()
1351
- self.Ly, self.Lx, _ = self.stack[self.currentZ].shape
1352
-
1353
- if self.view == 0 or self.view == self.ViewDropDown.count() - 1:
1354
- image = self.stack[
1355
- self.currentZ] if self.view == 0 else self.stack_filtered[self.currentZ]
1356
- if self.color == 0:
1357
- self.img.setImage(image, autoLevels=False, lut=None)
1358
- if self.nchan > 1:
1359
- levels = np.array([
1360
- self.saturation[0][self.currentZ],
1361
- self.saturation[1][self.currentZ],
1362
- self.saturation[2][self.currentZ]
1363
- ])
1364
- self.img.setLevels(levels)
1365
- else:
1366
- self.img.setLevels(self.saturation[0][self.currentZ])
1367
- elif self.color > 0 and self.color < 4:
1368
- if self.nchan > 1:
1369
- image = image[:, :, self.color - 1]
1370
- self.img.setImage(image, autoLevels=False, lut=self.cmap[self.color])
1371
- if self.nchan > 1:
1372
- self.img.setLevels(self.saturation[self.color - 1][self.currentZ])
1373
- else:
1374
- self.img.setLevels(self.saturation[0][self.currentZ])
1375
- elif self.color == 4:
1376
- if self.nchan > 1:
1377
- image = image.mean(axis=-1)
1378
- self.img.setImage(image, autoLevels=False, lut=None)
1379
- self.img.setLevels(self.saturation[0][self.currentZ])
1380
- elif self.color == 5:
1381
- if self.nchan > 1:
1382
- image = image.mean(axis=-1)
1383
- self.img.setImage(image, autoLevels=False, lut=self.cmap[0])
1384
- self.img.setLevels(self.saturation[0][self.currentZ])
1385
- else:
1386
- image = np.zeros((self.Ly, self.Lx), np.uint8)
1387
- if len(self.flows) >= self.view - 1 and len(self.flows[self.view - 1]) > 0:
1388
- image = self.flows[self.view - 1][self.currentZ]
1389
- if self.view > 1:
1390
- self.img.setImage(image, autoLevels=False, lut=self.bwr)
1391
- else:
1392
- self.img.setImage(image, autoLevels=False, lut=None)
1393
- self.img.setLevels([0.0, 255.0])
1394
-
1395
- for r in range(3):
1396
- self.sliders[r].setValue([
1397
- self.saturation[r][self.currentZ][0],
1398
- self.saturation[r][self.currentZ][1]
1399
- ])
1400
- self.win.show()
1401
- self.show()
1402
-
1403
-
1404
- def update_layer(self):
1405
- if self.masksOn or self.outlinesOn:
1406
- self.layer.setImage(self.layerz, autoLevels=False)
1407
- self.win.show()
1408
- self.show()
1409
-
1410
-
1411
- def add_set(self):
1412
- if len(self.current_point_set) > 0:
1413
- while len(self.strokes) > 0:
1414
- self.remove_stroke(delete_points=False)
1415
- if len(self.current_point_set[0]) > 8:
1416
- color = self.colormap[self.ncells.get(), :3]
1417
- median = self.add_mask(points=self.current_point_set, color=color)
1418
- if median is not None:
1419
- self.removed_cell = []
1420
- self.toggle_mask_ops()
1421
- self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :],
1422
- axis=0)
1423
- self.ncells += 1
1424
- self.ismanual = np.append(self.ismanual, True)
1425
- if self.NZ == 1:
1426
- # only save after each cell if single image
1427
- io._save_sets_with_check(self)
1428
- else:
1429
- print("GUI_ERROR: cell too small, not drawn")
1430
- self.current_stroke = []
1431
- self.strokes = []
1432
- self.current_point_set = []
1433
- self.update_layer()
1434
-
1435
- def add_mask(self, points=None, color=(100, 200, 50), dense=True):
1436
- # points is list of strokes
1437
- points_all = np.concatenate(points, axis=0)
1438
-
1439
- # loop over z values
1440
- median = []
1441
- zdraw = np.unique(points_all[:, 0])
1442
- z = 0
1443
- ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(
1444
- 0, "int"), np.zeros(0, "int")
1445
- for stroke in points:
1446
- stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
1447
- vr = stroke[:, 1]
1448
- vc = stroke[:, 2]
1449
- # get points inside drawn points
1450
- mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
1451
- pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
1452
- axis=-1)[:, np.newaxis, :]
1453
- mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
1454
- ar, ac = np.nonzero(mask)
1455
- ar, ac = ar + vr.min() - 2, ac + vc.min() - 2
1456
- # get dense outline
1457
- contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
1458
- pvc, pvr = contours[-2][0][:,0].T
1459
- vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
1460
- # concatenate all points
1461
- ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
1462
- # if these pixels are overlapping with another cell, reassign them
1463
- ioverlap = self.cellpix[z][ar, ac] > 0
1464
- if (~ioverlap).sum() < 10:
1465
- print("GUI_ERROR: cell < 10 pixels without overlaps, not drawn")
1466
- return None
1467
- elif ioverlap.sum() > 0:
1468
- ar, ac = ar[~ioverlap], ac[~ioverlap]
1469
- # compute outline of new mask
1470
- mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
1471
- mask[ar - vr.min() + 2, ac - vc.min() + 2] = 1
1472
- contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
1473
- cv2.CHAIN_APPROX_NONE)
1474
- pvc, pvr = contours[-2][0][:,0].T
1475
- vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
1476
- ars = np.concatenate((ars, ar), axis=0)
1477
- acs = np.concatenate((acs, ac), axis=0)
1478
- vrs = np.concatenate((vrs, vr), axis=0)
1479
- vcs = np.concatenate((vcs, vc), axis=0)
1480
-
1481
- self.draw_mask(z, ars, acs, vrs, vcs, color)
1482
- median.append(np.array([np.median(ars), np.median(acs)]))
1483
-
1484
- self.zdraw.append(zdraw)
1485
- d = datetime.datetime.now()
1486
- self.track_changes.append(
1487
- [d.strftime("%m/%d/%Y, %H:%M:%S"), "added mask", [ar, ac]])
1488
- return median
1489
-
1490
- def draw_mask(self, z, ar, ac, vr, vc, color, idx=None):
1491
- """ draw single mask using outlines and area """
1492
- if idx is None:
1493
- idx = self.ncells + 1
1494
- self.cellpix[z, vr, vc] = idx
1495
- self.cellpix[z, ar, ac] = idx
1496
- self.outpix[z, vr, vc] = idx
1497
- if self.restore and "upsample" in self.restore:
1498
- if self.resize:
1499
- self.cellpix_resize[z, vr, vc] = idx
1500
- self.cellpix_resize[z, ar, ac] = idx
1501
- self.outpix_resize[z, vr, vc] = idx
1502
- self.cellpix_orig[z, (vr / self.ratio).astype(int),
1503
- (vc / self.ratio).astype(int)] = idx
1504
- self.cellpix_orig[z, (ar / self.ratio).astype(int),
1505
- (ac / self.ratio).astype(int)] = idx
1506
- self.outpix_orig[z, (vr / self.ratio).astype(int),
1507
- (vc / self.ratio).astype(int)] = idx
1508
- else:
1509
- self.cellpix_orig[z, vr, vc] = idx
1510
- self.cellpix_orig[z, ar, ac] = idx
1511
- self.outpix_orig[z, vr, vc] = idx
1512
-
1513
- # get upsampled mask
1514
- vrr = (vr.copy() * self.ratio).astype(int)
1515
- vcr = (vc.copy() * self.ratio).astype(int)
1516
- mask = np.zeros((np.ptp(vrr) + 4, np.ptp(vcr) + 4), np.uint8)
1517
- pts = np.stack((vcr - vcr.min() + 2, vrr - vrr.min() + 2),
1518
- axis=-1)[:, np.newaxis, :]
1519
- mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
1520
- arr, acr = np.nonzero(mask)
1521
- arr, acr = arr + vrr.min() - 2, acr + vcr.min() - 2
1522
- # get dense outline
1523
- contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
1524
- cv2.CHAIN_APPROX_NONE)
1525
- pvc, pvr = contours[-2][0].squeeze().T
1526
- vrr, vcr = pvr + vrr.min() - 2, pvc + vcr.min() - 2
1527
- # concatenate all points
1528
- arr, acr = np.hstack((np.vstack((vrr, vcr)), np.vstack((arr, acr))))
1529
- self.cellpix_resize[z, vrr, vcr] = idx
1530
- self.cellpix_resize[z, arr, acr] = idx
1531
- self.outpix_resize[z, vrr, vcr] = idx
1532
-
1533
- if z == self.currentZ:
1534
- self.layerz[ar, ac, :3] = color
1535
- if self.masksOn:
1536
- self.layerz[ar, ac, -1] = self.opacity
1537
- if self.outlinesOn:
1538
- self.layerz[vr, vc] = np.array(self.outcolor)
1539
-
1540
- def compute_scale(self):
1541
- # get diameter from gui
1542
- diameter = self.segmentation_settings.diameter
1543
- if not diameter:
1544
- diameter = 30
1545
-
1546
- self.pr = int(diameter)
1547
- self.radii_padding = int(self.pr * 1.25)
1548
- self.radii = np.zeros((self.Ly + self.radii_padding, self.Lx, 4), np.uint8)
1549
- yy, xx = disk([self.Ly + self.radii_padding / 2 - 1, self.pr / 2 + 1],
1550
- self.pr / 2, self.Ly + self.radii_padding, self.Lx)
1551
- # rgb(150,50,150)
1552
- self.radii[yy, xx, 0] = 150
1553
- self.radii[yy, xx, 1] = 50
1554
- self.radii[yy, xx, 2] = 150
1555
- self.radii[yy, xx, 3] = 255
1556
- self.p0.setYRange(0, self.Ly + self.radii_padding)
1557
- self.p0.setXRange(0, self.Lx)
1558
-
1559
- def update_scale(self):
1560
- self.compute_scale()
1561
- self.scale.setImage(self.radii, autoLevels=False)
1562
- self.scale.setLevels([0.0, 255.0])
1563
- self.win.show()
1564
- self.show()
1565
-
1566
-
1567
- def draw_layer(self):
1568
- if self.resize:
1569
- self.Ly, self.Lx = self.Lyr, self.Lxr
1570
- else:
1571
- self.Ly, self.Lx = self.Ly0, self.Lx0
1572
-
1573
- if self.masksOn or self.outlinesOn:
1574
- if self.restore and "upsample" in self.restore:
1575
- if self.resize:
1576
- self.cellpix = self.cellpix_resize.copy()
1577
- self.outpix = self.outpix_resize.copy()
1578
- else:
1579
- self.cellpix = self.cellpix_orig.copy()
1580
- self.outpix = self.outpix_orig.copy()
1581
-
1582
- self.layerz = np.zeros((self.Ly, self.Lx, 4), np.uint8)
1583
- if self.masksOn:
1584
- self.layerz[..., :3] = self.cellcolors[self.cellpix[self.currentZ], :]
1585
- self.layerz[..., 3] = self.opacity * (self.cellpix[self.currentZ]
1586
- > 0).astype(np.uint8)
1587
- if self.selected > 0:
1588
- self.layerz[self.cellpix[self.currentZ] == self.selected] = np.array(
1589
- [255, 255, 255, self.opacity])
1590
- cZ = self.currentZ
1591
- stroke_z = np.array([s[0][0] for s in self.strokes])
1592
- inZ = np.nonzero(stroke_z == cZ)[0]
1593
- if len(inZ) > 0:
1594
- for i in inZ:
1595
- stroke = np.array(self.strokes[i])
1596
- self.layerz[stroke[:, 1], stroke[:,
1597
- 2]] = np.array([255, 0, 255, 100])
1598
- else:
1599
- self.layerz[..., 3] = 0
1600
-
1601
- if self.outlinesOn:
1602
- self.layerz[self.outpix[self.currentZ] > 0] = np.array(
1603
- self.outcolor).astype(np.uint8)
1604
-
1605
-
1606
- def set_normalize_params(self, normalize_params):
1607
- from cellpose.models import normalize_default
1608
- if self.restore != "filter":
1609
- keys = list(normalize_params.keys()).copy()
1610
- for key in keys:
1611
- if key != "percentile":
1612
- normalize_params[key] = normalize_default[key]
1613
- normalize_params = {**normalize_default, **normalize_params}
1614
- out = self.check_filter_params(normalize_params["sharpen_radius"],
1615
- normalize_params["smooth_radius"],
1616
- normalize_params["tile_norm_blocksize"],
1617
- normalize_params["tile_norm_smooth3D"],
1618
- normalize_params["norm3D"],
1619
- normalize_params["invert"])
1620
-
1621
-
1622
- def check_filter_params(self, sharpen, smooth, tile_norm, smooth3D, norm3D, invert):
1623
- tile_norm = 0 if tile_norm < 0 else tile_norm
1624
- sharpen = 0 if sharpen < 0 else sharpen
1625
- smooth = 0 if smooth < 0 else smooth
1626
- smooth3D = 0 if smooth3D < 0 else smooth3D
1627
- norm3D = bool(norm3D)
1628
- invert = bool(invert)
1629
- if tile_norm > self.Ly and tile_norm > self.Lx:
1630
- print(
1631
- "GUI_ERROR: tile size (tile_norm) bigger than both image dimensions, disabling"
1632
- )
1633
- tile_norm = 0
1634
- self.filt_edits[0].setText(str(sharpen))
1635
- self.filt_edits[1].setText(str(smooth))
1636
- self.filt_edits[2].setText(str(tile_norm))
1637
- self.filt_edits[3].setText(str(smooth3D))
1638
- self.norm3D_cb.setChecked(norm3D)
1639
- return sharpen, smooth, tile_norm, smooth3D, norm3D, invert
1640
-
1641
- def get_normalize_params(self):
1642
- percentile = [
1643
- self.segmentation_settings.low_percentile,
1644
- self.segmentation_settings.high_percentile,
1645
- ]
1646
- normalize_params = {"percentile": percentile}
1647
- norm3D = self.norm3D_cb.isChecked()
1648
- normalize_params["norm3D"] = norm3D
1649
- sharpen = float(self.filt_edits[0].text())
1650
- smooth = float(self.filt_edits[1].text())
1651
- tile_norm = float(self.filt_edits[2].text())
1652
- smooth3D = float(self.filt_edits[3].text())
1653
- invert = False
1654
- out = self.check_filter_params(sharpen, smooth, tile_norm, smooth3D, norm3D,
1655
- invert)
1656
- sharpen, smooth, tile_norm, smooth3D, norm3D, invert = out
1657
- normalize_params["sharpen_radius"] = sharpen
1658
- normalize_params["smooth_radius"] = smooth
1659
- normalize_params["tile_norm_blocksize"] = tile_norm
1660
- normalize_params["tile_norm_smooth3D"] = smooth3D
1661
- normalize_params["invert"] = invert
1662
-
1663
- from cellpose.models import normalize_default
1664
- normalize_params = {**normalize_default, **normalize_params}
1665
-
1666
- return normalize_params
1667
-
1668
- def compute_saturation_if_checked(self):
1669
- if self.autobtn.isChecked():
1670
- self.compute_saturation()
1671
-
1672
- def compute_saturation(self, return_img=False):
1673
- norm = self.get_normalize_params()
1674
- print(norm)
1675
- sharpen, smooth = norm["sharpen_radius"], norm["smooth_radius"]
1676
- percentile = norm["percentile"]
1677
- tile_norm = norm["tile_norm_blocksize"]
1678
- invert = norm["invert"]
1679
- norm3D = norm["norm3D"]
1680
- smooth3D = norm["tile_norm_smooth3D"]
1681
- tile_norm = norm["tile_norm_blocksize"]
1682
-
1683
- if sharpen > 0 or smooth > 0 or tile_norm > 0:
1684
- img_norm = self.stack.copy()
1685
- else:
1686
- img_norm = self.stack
1687
-
1688
- if sharpen > 0 or smooth > 0 or tile_norm > 0:
1689
- self.restore = "filter"
1690
- print(
1691
- "GUI_INFO: computing filtered image because sharpen > 0 or tile_norm > 0"
1692
- )
1693
- print(
1694
- "GUI_WARNING: will use memory to create filtered image -- make sure to have RAM for this"
1695
- )
1696
- img_norm = self.stack.copy()
1697
- if sharpen > 0 or smooth > 0:
1698
- img_norm = smooth_sharpen_img(self.stack, sharpen_radius=sharpen,
1699
- smooth_radius=smooth)
1700
-
1701
- if tile_norm > 0:
1702
- img_norm = normalize99_tile(img_norm, blocksize=tile_norm,
1703
- lower=percentile[0], upper=percentile[1],
1704
- smooth3D=smooth3D, norm3D=norm3D)
1705
- # convert to 0->255
1706
- img_norm_min = img_norm.min()
1707
- img_norm_max = img_norm.max()
1708
- for c in range(img_norm.shape[-1]):
1709
- if np.ptp(img_norm[..., c]) > 1e-3:
1710
- img_norm[..., c] -= img_norm_min
1711
- img_norm[..., c] /= (img_norm_max - img_norm_min)
1712
- img_norm *= 255
1713
- self.stack_filtered = img_norm
1714
- self.ViewDropDown.model().item(self.ViewDropDown.count() -
1715
- 1).setEnabled(True)
1716
- self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1)
1717
- else:
1718
- img_norm = self.stack if self.restore is None or self.restore == "filter" else self.stack_filtered
1719
-
1720
- if self.autobtn.isChecked():
1721
- self.saturation = []
1722
- for c in range(img_norm.shape[-1]):
1723
- self.saturation.append([])
1724
- if np.ptp(img_norm[..., c]) > 1e-3:
1725
- if norm3D:
1726
- x01 = np.percentile(img_norm[..., c], percentile[0])
1727
- x99 = np.percentile(img_norm[..., c], percentile[1])
1728
- if invert:
1729
- x01i = 255. - x99
1730
- x99i = 255. - x01
1731
- x01, x99 = x01i, x99i
1732
- for n in range(self.NZ):
1733
- self.saturation[-1].append([x01, x99])
1734
- else:
1735
- for z in range(self.NZ):
1736
- if self.NZ > 1:
1737
- x01 = np.percentile(img_norm[z, :, :, c], percentile[0])
1738
- x99 = np.percentile(img_norm[z, :, :, c], percentile[1])
1739
- else:
1740
- x01 = np.percentile(img_norm[..., c], percentile[0])
1741
- x99 = np.percentile(img_norm[..., c], percentile[1])
1742
- if invert:
1743
- x01i = 255. - x99
1744
- x99i = 255. - x01
1745
- x01, x99 = x01i, x99i
1746
- self.saturation[-1].append([x01, x99])
1747
- else:
1748
- for n in range(self.NZ):
1749
- self.saturation[-1].append([0, 255.])
1750
- print(self.saturation[2][self.currentZ])
1751
-
1752
- if img_norm.shape[-1] == 1:
1753
- self.saturation.append(self.saturation[0])
1754
- self.saturation.append(self.saturation[0])
1755
-
1756
- # self.autobtn.setChecked(True)
1757
- self.update_plot()
1758
-
1759
-
1760
- def get_model_path(self, custom=False):
1761
- if custom:
1762
- self.current_model = self.ModelChooseC.currentText()
1763
- self.current_model_path = os.fspath(
1764
- models.MODEL_DIR.joinpath(self.current_model))
1765
- else:
1766
- self.current_model = "cpsam"
1767
- self.current_model_path = models.model_path(self.current_model)
1768
-
1769
- def initialize_model(self, model_name=None, custom=False):
1770
- if model_name is None or custom:
1771
- self.get_model_path(custom=custom)
1772
- if not os.path.exists(self.current_model_path):
1773
- raise ValueError("need to specify model (use dropdown)")
1774
-
1775
- if model_name is None or not isinstance(model_name, str):
1776
- self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
1777
- pretrained_model=self.current_model_path)
1778
- else:
1779
- self.current_model = model_name
1780
- self.current_model_path = os.fspath(
1781
- models.MODEL_DIR.joinpath(self.current_model))
1782
-
1783
- self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
1784
- pretrained_model=self.current_model)
1785
-
1786
- def add_model(self):
1787
- io._add_model(self)
1788
- return
1789
-
1790
- def remove_model(self):
1791
- io._remove_model(self)
1792
- return
1793
-
1794
- def new_model(self):
1795
- if self.NZ != 1:
1796
- print("ERROR: cannot train model on 3D data")
1797
- return
1798
-
1799
- # train model
1800
- image_names = self.get_files()[0]
1801
- self.train_data, self.train_labels, self.train_files, restore, normalize_params = io._get_train_set(
1802
- image_names)
1803
- TW = guiparts.TrainWindow(self, models.MODEL_NAMES)
1804
- train = TW.exec_()
1805
- if train:
1806
- self.logger.info(
1807
- f"training with {[os.path.split(f)[1] for f in self.train_files]}")
1808
- self.train_model(restore=restore, normalize_params=normalize_params)
1809
- else:
1810
- print("GUI_INFO: training cancelled")
1811
-
1812
- def train_model(self, restore=None, normalize_params=None):
1813
- from cellpose.models import normalize_default
1814
- if normalize_params is None:
1815
- normalize_params = copy.deepcopy(normalize_default)
1816
- model_type = models.MODEL_NAMES[self.training_params["model_index"]]
1817
- self.logger.info(f"training new model starting at model {model_type}")
1818
- self.current_model = model_type
1819
-
1820
- self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
1821
- model_type=model_type)
1822
- save_path = os.path.dirname(self.filename)
1823
-
1824
- print("GUI_INFO: name of new model: " + self.training_params["model_name"])
1825
- self.new_model_path, train_losses = train.train_seg(
1826
- self.model.net, train_data=self.train_data, train_labels=self.train_labels,
1827
- normalize=normalize_params, min_train_masks=0,
1828
- save_path=save_path, nimg_per_epoch=max(2, len(self.train_data)),
1829
- learning_rate=self.training_params["learning_rate"],
1830
- weight_decay=self.training_params["weight_decay"],
1831
- n_epochs=self.training_params["n_epochs"],
1832
- model_name=self.training_params["model_name"])[:2]
1833
- # save train losses
1834
- np.save(str(self.new_model_path) + "_train_losses.npy", train_losses)
1835
- # run model on next image
1836
- io._add_model(self, self.new_model_path)
1837
- diam_labels = self.model.net.diam_labels.item() #.copy()
1838
- self.new_model_ind = len(self.model_strings)
1839
- self.autorun = True
1840
- self.clear_all()
1841
- self.restore = restore
1842
- self.set_normalize_params(normalize_params)
1843
- self.get_next_image(load_seg=False)
1844
-
1845
- self.compute_segmentation(custom=True)
1846
- self.logger.info(
1847
- f"!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!"
1848
- )
1849
-
1850
-
1851
- def compute_cprob(self):
1852
- if self.recompute_masks:
1853
- flow_threshold = self.segmentation_settings.flow_threshold
1854
- cellprob_threshold = self.segmentation_settings.cellprob_threshold
1855
- niter = self.segmentation_settings.niter
1856
- min_size = int(self.min_size.text()) if not isinstance(
1857
- self.min_size, int) else self.min_size
1858
-
1859
- self.logger.info(
1860
- "computing masks with cell prob=%0.3f, flow error threshold=%0.3f" %
1861
- (cellprob_threshold, flow_threshold))
1862
-
1863
- try:
1864
- dP = self.flows[2].squeeze()
1865
- cellprob = self.flows[3].squeeze()
1866
- except IndexError:
1867
- self.logger.error("Flows don't exist, try running model again.")
1868
- return
1869
-
1870
- maski = dynamics.resize_and_compute_masks(
1871
- dP=dP,
1872
- cellprob=cellprob,
1873
- niter=niter,
1874
- do_3D=self.load_3D,
1875
- min_size=min_size,
1876
- # max_size_fraction=min_size_fraction, # Leave as default
1877
- cellprob_threshold=cellprob_threshold,
1878
- flow_threshold=flow_threshold)
1879
-
1880
- self.masksOn = True
1881
- if not self.OCheckBox.isChecked():
1882
- self.MCheckBox.setChecked(True)
1883
- if maski.ndim < 3:
1884
- maski = maski[np.newaxis, ...]
1885
- self.logger.info("%d cells found" % (len(np.unique(maski)[1:])))
1886
- io._masks_to_gui(self, maski, outlines=None)
1887
- self.show()
1888
-
1889
-
1890
- def compute_segmentation(self, custom=False, model_name=None, load_model=True):
1891
- self.progress.setValue(0)
1892
- try:
1893
- tic = time.time()
1894
- self.clear_all()
1895
- self.flows = [[], [], []]
1896
- if load_model:
1897
- self.initialize_model(model_name=model_name, custom=custom)
1898
- self.progress.setValue(10)
1899
- do_3D = self.load_3D
1900
- stitch_threshold = float(self.stitch_threshold.text()) if not isinstance(
1901
- self.stitch_threshold, float) else self.stitch_threshold
1902
- anisotropy = float(self.anisotropy.text()) if not isinstance(
1903
- self.anisotropy, float) else self.anisotropy
1904
- flow3D_smooth = float(self.flow3D_smooth.text()) if not isinstance(
1905
- self.flow3D_smooth, float) else self.flow3D_smooth
1906
- min_size = int(self.min_size.text()) if not isinstance(
1907
- self.min_size, int) else self.min_size
1908
-
1909
- do_3D = False if stitch_threshold > 0. else do_3D
1910
-
1911
- if self.restore == "filter":
1912
- data = self.stack_filtered.copy().squeeze()
1913
- else:
1914
- data = self.stack.copy().squeeze()
1915
-
1916
- flow_threshold = self.segmentation_settings.flow_threshold
1917
- cellprob_threshold = self.segmentation_settings.cellprob_threshold
1918
- diameter = self.segmentation_settings.diameter
1919
- niter = self.segmentation_settings.niter
1920
-
1921
- normalize_params = self.get_normalize_params()
1922
- print(normalize_params)
1923
- try:
1924
- masks, flows = self.model.eval(
1925
- data,
1926
- diameter=diameter,
1927
- cellprob_threshold=cellprob_threshold,
1928
- flow_threshold=flow_threshold, do_3D=do_3D, niter=niter,
1929
- normalize=normalize_params, stitch_threshold=stitch_threshold,
1930
- anisotropy=anisotropy, flow3D_smooth=flow3D_smooth,
1931
- min_size=min_size, channel_axis=-1,
1932
- progress=self.progress, z_axis=0 if self.NZ > 1 else None)[:2]
1933
- except Exception as e:
1934
- print("NET ERROR: %s" % e)
1935
- self.progress.setValue(0)
1936
- return
1937
-
1938
- self.progress.setValue(75)
1939
-
1940
- # convert flows to uint8 and resize to original image size
1941
- flows_new = []
1942
- flows_new.append(flows[0].copy()) # RGB flow
1943
- flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) *
1944
- 255).astype("uint8")) # cellprob
1945
- flows_new.append(flows[1].copy()) # XY flows
1946
- flows_new.append(flows[2].copy()) # original cellprob
1947
-
1948
- if self.load_3D:
1949
- if stitch_threshold == 0.:
1950
- flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8"))
1951
- else:
1952
- flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8"))
1953
-
1954
- if not self.load_3D:
1955
- if self.restore and "upsample" in self.restore:
1956
- self.Ly, self.Lx = self.Lyr, self.Lxr
1957
-
1958
- if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx):
1959
- self.flows = []
1960
- for j in range(len(flows_new)):
1961
- self.flows.append(
1962
- resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx,
1963
- interpolation=cv2.INTER_NEAREST))
1964
- else:
1965
- self.flows = flows_new
1966
- else:
1967
- self.flows = []
1968
- Lz, Ly, Lx = self.NZ, self.Ly, self.Lx
1969
- Lz0, Ly0, Lx0 = flows_new[0].shape[:3]
1970
- print("GUI_INFO: resizing flows to original image size")
1971
- for j in range(len(flows_new)):
1972
- flow0 = flows_new[j]
1973
- if Ly0 != Ly:
1974
- flow0 = resize_image(flow0, Ly=Ly, Lx=Lx,
1975
- no_channels=flow0.ndim==3,
1976
- interpolation=cv2.INTER_NEAREST)
1977
- if Lz0 != Lz:
1978
- flow0 = np.swapaxes(resize_image(np.swapaxes(flow0, 0, 1),
1979
- Ly=Lz, Lx=Lx,
1980
- no_channels=flow0.ndim==3,
1981
- interpolation=cv2.INTER_NEAREST), 0, 1)
1982
- self.flows.append(flow0)
1983
-
1984
- # add first axis
1985
- if self.NZ == 1:
1986
- masks = masks[np.newaxis, ...]
1987
- self.flows = [
1988
- self.flows[n][np.newaxis, ...] for n in range(len(self.flows))
1989
- ]
1990
-
1991
- self.logger.info("%d cells found with model in %0.3f sec" %
1992
- (len(np.unique(masks)[1:]), time.time() - tic))
1993
- self.progress.setValue(80)
1994
- z = 0
1995
-
1996
- io._masks_to_gui(self, masks, outlines=None)
1997
- self.masksOn = True
1998
- self.MCheckBox.setChecked(True)
1999
- self.progress.setValue(100)
2000
- if self.restore != "filter" and self.restore is not None and self.autobtn.isChecked():
2001
- self.compute_saturation()
2002
- if not do_3D and not stitch_threshold > 0:
2003
- self.recompute_masks = True
2004
- else:
2005
- self.recompute_masks = False
2006
- except Exception as e:
2007
- print("ERROR: %s" % e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/gui/gui3d.py DELETED
@@ -1,667 +0,0 @@
1
- """
2
- Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
- """
4
-
5
- import sys, pathlib, warnings
6
-
7
- from qtpy import QtGui, QtCore
8
- from qtpy.QtWidgets import QApplication, QScrollBar, QCheckBox, QLabel, QLineEdit
9
- import pyqtgraph as pg
10
-
11
- import numpy as np
12
- from scipy.stats import mode
13
- import cv2
14
-
15
- from . import guiparts, io
16
- from ..utils import download_url_to_file, masks_to_outlines
17
- from .gui import MainW
18
-
19
- try:
20
- import matplotlib.pyplot as plt
21
- MATPLOTLIB = True
22
- except:
23
- MATPLOTLIB = False
24
-
25
-
26
- def avg3d(C):
27
- """ smooth value of c across nearby points
28
- (c is center of grid directly below point)
29
- b -- a -- b
30
- a -- c -- a
31
- b -- a -- b
32
- """
33
- Ly, Lx = C.shape
34
- # pad T by 2
35
- T = np.zeros((Ly + 2, Lx + 2), "float32")
36
- M = np.zeros((Ly, Lx), "float32")
37
- T[1:-1, 1:-1] = C.copy()
38
- y, x = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
39
- indexing="ij")
40
- y += 1
41
- x += 1
42
- a = 1. / 2 #/(z**2 + 1)**0.5
43
- b = 1. / (1 + 2**0.5) #(z**2 + 2)**0.5
44
- c = 1.
45
- M = (b * T[y - 1, x - 1] + a * T[y - 1, x] + b * T[y - 1, x + 1] + a * T[y, x - 1] +
46
- c * T[y, x] + a * T[y, x + 1] + b * T[y + 1, x - 1] + a * T[y + 1, x] +
47
- b * T[y + 1, x + 1])
48
- M /= 4 * a + 4 * b + c
49
- return M
50
-
51
-
52
- def interpZ(mask, zdraw):
53
- """ find nearby planes and average their values using grid of points
54
- zfill is in ascending order
55
- """
56
- ifill = np.ones(mask.shape[0], "bool")
57
- zall = np.arange(0, mask.shape[0], 1, int)
58
- ifill[zdraw] = False
59
- zfill = zall[ifill]
60
- zlower = zdraw[np.searchsorted(zdraw, zfill, side="left") - 1]
61
- zupper = zdraw[np.searchsorted(zdraw, zfill, side="right")]
62
- for k, z in enumerate(zfill):
63
- Z = zupper[k] - zlower[k]
64
- zl = (z - zlower[k]) / Z
65
- plower = avg3d(mask[zlower[k]]) * (1 - zl)
66
- pupper = avg3d(mask[zupper[k]]) * zl
67
- mask[z] = (plower + pupper) > 0.33
68
- return mask, zfill
69
-
70
-
71
- def run(image=None):
72
- from ..io import logger_setup
73
- logger, log_file = logger_setup()
74
- # Always start by initializing Qt (only once per application)
75
- warnings.filterwarnings("ignore")
76
- app = QApplication(sys.argv)
77
- icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
78
- guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png")
79
- style_path = pathlib.Path.home().joinpath(".cellpose", "style_choice.npy")
80
- if not icon_path.is_file():
81
- cp_dir = pathlib.Path.home().joinpath(".cellpose")
82
- cp_dir.mkdir(exist_ok=True)
83
- print("downloading logo")
84
- download_url_to_file(
85
- "https://www.cellpose.org/static/images/cellpose_transparent.png",
86
- icon_path, progress=True)
87
- if not guip_path.is_file():
88
- print("downloading help window image")
89
- download_url_to_file("https://www.cellpose.org/static/images/cellpose_gui.png",
90
- guip_path, progress=True)
91
- icon_path = str(icon_path.resolve())
92
- app_icon = QtGui.QIcon()
93
- app_icon.addFile(icon_path, QtCore.QSize(16, 16))
94
- app_icon.addFile(icon_path, QtCore.QSize(24, 24))
95
- app_icon.addFile(icon_path, QtCore.QSize(32, 32))
96
- app_icon.addFile(icon_path, QtCore.QSize(48, 48))
97
- app_icon.addFile(icon_path, QtCore.QSize(64, 64))
98
- app_icon.addFile(icon_path, QtCore.QSize(256, 256))
99
- app.setWindowIcon(app_icon)
100
- app.setStyle("Fusion")
101
- app.setPalette(guiparts.DarkPalette())
102
- MainW_3d(image=image, logger=logger)
103
- ret = app.exec_()
104
- sys.exit(ret)
105
-
106
-
107
- class MainW_3d(MainW):
108
-
109
- def __init__(self, image=None, logger=None):
110
- # MainW init
111
- MainW.__init__(self, image=image, logger=logger)
112
-
113
- # add gradZ view
114
- self.ViewDropDown.insertItem(3, "gradZ")
115
-
116
- # turn off single stroke
117
- self.SCheckBox.setChecked(False)
118
-
119
- ### add orthoviews and z-bar
120
- # ortho crosshair lines
121
- self.vLine = pg.InfiniteLine(angle=90, movable=False)
122
- self.hLine = pg.InfiniteLine(angle=0, movable=False)
123
- self.vLineOrtho = [
124
- pg.InfiniteLine(angle=90, movable=False),
125
- pg.InfiniteLine(angle=90, movable=False)
126
- ]
127
- self.hLineOrtho = [
128
- pg.InfiniteLine(angle=0, movable=False),
129
- pg.InfiniteLine(angle=0, movable=False)
130
- ]
131
- self.make_orthoviews()
132
-
133
- # z scrollbar underneath
134
- self.scroll = QScrollBar(QtCore.Qt.Horizontal)
135
- self.scroll.setMaximum(10)
136
- self.scroll.valueChanged.connect(self.move_in_Z)
137
- self.lmain.addWidget(self.scroll, 40, 9, 1, 30)
138
-
139
- b = 22
140
-
141
- label = QLabel("stitch\nthreshold:")
142
- label.setToolTip(
143
- "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
144
- )
145
- label.setFont(self.medfont)
146
- self.segBoxG.addWidget(label, b, 0, 1, 4)
147
- self.stitch_threshold = QLineEdit()
148
- self.stitch_threshold.setText("0.0")
149
- self.stitch_threshold.setFixedWidth(30)
150
- self.stitch_threshold.setFont(self.medfont)
151
- self.stitch_threshold.setToolTip(
152
- "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
153
- )
154
- self.segBoxG.addWidget(self.stitch_threshold, b, 3, 1, 1)
155
-
156
- label = QLabel("flow3D\nsmooth:")
157
- label.setToolTip(
158
- "for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)"
159
- )
160
- label.setFont(self.medfont)
161
- self.segBoxG.addWidget(label, b, 4, 1, 3)
162
- self.flow3D_smooth = QLineEdit()
163
- self.flow3D_smooth.setText("0.0")
164
- self.flow3D_smooth.setFixedWidth(30)
165
- self.flow3D_smooth.setFont(self.medfont)
166
- self.flow3D_smooth.setToolTip(
167
- "for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)"
168
- )
169
- self.segBoxG.addWidget(self.flow3D_smooth, b, 7, 1, 1)
170
-
171
- b+=1
172
- label = QLabel("anisotropy:")
173
- label.setToolTip(
174
- "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)"
175
- )
176
- label.setFont(self.medfont)
177
- self.segBoxG.addWidget(label, b, 0, 1, 3)
178
- self.anisotropy = QLineEdit()
179
- self.anisotropy.setText("1.0")
180
- self.anisotropy.setFixedWidth(30)
181
- self.anisotropy.setFont(self.medfont)
182
- self.anisotropy.setToolTip(
183
- "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)"
184
- )
185
- self.segBoxG.addWidget(self.anisotropy, b, 3, 1, 1)
186
-
187
- b+=1
188
- label = QLabel("min\nsize:")
189
- label.setToolTip(
190
- "all masks less than this size in pixels (volume) will be removed"
191
- )
192
- label.setFont(self.medfont)
193
- self.segBoxG.addWidget(label, b, 0, 1, 4)
194
- self.min_size = QLineEdit()
195
- self.min_size.setText("15")
196
- self.min_size.setFixedWidth(50)
197
- self.min_size.setFont(self.medfont)
198
- self.min_size.setToolTip(
199
- "all masks less than this size in pixels (volume) will be removed"
200
- )
201
- self.segBoxG.addWidget(self.min_size, b, 3, 1, 1)
202
-
203
- b += 1
204
- self.orthobtn = QCheckBox("ortho")
205
- self.orthobtn.setToolTip("activate orthoviews with 3D image")
206
- self.orthobtn.setFont(self.medfont)
207
- self.orthobtn.setChecked(False)
208
- self.l0.addWidget(self.orthobtn, b, 0, 1, 2)
209
- self.orthobtn.toggled.connect(self.toggle_ortho)
210
-
211
- label = QLabel("dz:")
212
- label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
213
- label.setFont(self.medfont)
214
- self.l0.addWidget(label, b, 2, 1, 1)
215
- self.dz = 10
216
- self.dzedit = QLineEdit()
217
- self.dzedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
218
- self.dzedit.setText(str(self.dz))
219
- self.dzedit.returnPressed.connect(self.update_ortho)
220
- self.dzedit.setFixedWidth(40)
221
- self.dzedit.setFont(self.medfont)
222
- self.l0.addWidget(self.dzedit, b, 3, 1, 2)
223
-
224
- label = QLabel("z-aspect:")
225
- label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
226
- label.setFont(self.medfont)
227
- self.l0.addWidget(label, b, 5, 1, 2)
228
- self.zaspect = 1.0
229
- self.zaspectedit = QLineEdit()
230
- self.zaspectedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
231
- self.zaspectedit.setText(str(self.zaspect))
232
- self.zaspectedit.returnPressed.connect(self.update_ortho)
233
- self.zaspectedit.setFixedWidth(40)
234
- self.zaspectedit.setFont(self.medfont)
235
- self.l0.addWidget(self.zaspectedit, b, 7, 1, 2)
236
-
237
- b += 1
238
- # add z position underneath
239
- self.currentZ = 0
240
- label = QLabel("Z:")
241
- label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
242
- self.l0.addWidget(label, b, 5, 1, 2)
243
- self.zpos = QLineEdit()
244
- self.zpos.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
245
- self.zpos.setText(str(self.currentZ))
246
- self.zpos.returnPressed.connect(self.update_ztext)
247
- self.zpos.setFixedWidth(40)
248
- self.zpos.setFont(self.medfont)
249
- self.l0.addWidget(self.zpos, b, 7, 1, 2)
250
-
251
- # if called with image, load it
252
- if image is not None:
253
- self.filename = image
254
- io._load_image(self, self.filename, load_3D=True)
255
-
256
- self.load_3D = True
257
-
258
- def add_mask(self, points=None, color=(100, 200, 50), dense=True):
259
- # points is list of strokes
260
-
261
- points_all = np.concatenate(points, axis=0)
262
-
263
- # loop over z values
264
- median = []
265
- zdraw = np.unique(points_all[:, 0])
266
- zrange = np.arange(zdraw.min(), zdraw.max() + 1, 1, int)
267
- zmin = zdraw.min()
268
- pix = np.zeros((2, 0), "uint16")
269
- mall = np.zeros((len(zrange), self.Ly, self.Lx), "bool")
270
- k = 0
271
- for z in zdraw:
272
- ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(
273
- 0, "int"), np.zeros(0, "int")
274
- for stroke in points:
275
- stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
276
- iz = stroke[:, 0] == z
277
- vr = stroke[iz, 1]
278
- vc = stroke[iz, 2]
279
- if iz.sum() > 0:
280
- # get points inside drawn points
281
- mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8")
282
- pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
283
- axis=-1)[:, np.newaxis, :]
284
- mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
285
- ar, ac = np.nonzero(mask)
286
- ar, ac = ar + vr.min() - 2, ac + vc.min() - 2
287
- # get dense outline
288
- contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
289
- cv2.CHAIN_APPROX_NONE)
290
- pvc, pvr = contours[-2][0].squeeze().T
291
- vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
292
- # concatenate all points
293
- ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
294
- # if these pixels are overlapping with another cell, reassign them
295
- ioverlap = self.cellpix[z][ar, ac] > 0
296
- if (~ioverlap).sum() < 8:
297
- print("ERROR: cell too small without overlaps, not drawn")
298
- return None
299
- elif ioverlap.sum() > 0:
300
- ar, ac = ar[~ioverlap], ac[~ioverlap]
301
- # compute outline of new mask
302
- mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), "uint8")
303
- mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
304
- contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
305
- cv2.CHAIN_APPROX_NONE)
306
- pvc, pvr = contours[-2][0].squeeze().T
307
- vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
308
- ars = np.concatenate((ars, ar), axis=0)
309
- acs = np.concatenate((acs, ac), axis=0)
310
- vrs = np.concatenate((vrs, vr), axis=0)
311
- vcs = np.concatenate((vcs, vc), axis=0)
312
- self.draw_mask(z, ars, acs, vrs, vcs, color)
313
-
314
- median.append(np.array([np.median(ars), np.median(acs)]))
315
- mall[z - zmin, ars, acs] = True
316
- pix = np.append(pix, np.vstack((ars, acs)), axis=-1)
317
-
318
- mall = mall[:, pix[0].min():pix[0].max() + 1,
319
- pix[1].min():pix[1].max() + 1].astype("float32")
320
- ymin, xmin = pix[0].min(), pix[1].min()
321
- if len(zdraw) > 1:
322
- mall, zfill = interpZ(mall, zdraw - zmin)
323
- for z in zfill:
324
- mask = mall[z].copy()
325
- ar, ac = np.nonzero(mask)
326
- ioverlap = self.cellpix[z + zmin][ar + ymin, ac + xmin] > 0
327
- if (~ioverlap).sum() < 5:
328
- print("WARNING: stroke on plane %d not included due to overlaps" %
329
- z)
330
- elif ioverlap.sum() > 0:
331
- mask[ar[ioverlap], ac[ioverlap]] = 0
332
- ar, ac = ar[~ioverlap], ac[~ioverlap]
333
- # compute outline of mask
334
- outlines = masks_to_outlines(mask)
335
- vr, vc = np.nonzero(outlines)
336
- vr, vc = vr + ymin, vc + xmin
337
- ar, ac = ar + ymin, ac + xmin
338
- self.draw_mask(z + zmin, ar, ac, vr, vc, color)
339
-
340
- self.zdraw.append(zdraw)
341
-
342
- return median
343
-
344
- def move_in_Z(self):
345
- if self.loaded:
346
- self.currentZ = min(self.NZ, max(0, int(self.scroll.value())))
347
- self.zpos.setText(str(self.currentZ))
348
- self.update_plot()
349
- self.draw_layer()
350
- self.update_layer()
351
-
352
- def make_orthoviews(self):
353
- self.pOrtho, self.imgOrtho, self.layerOrtho = [], [], []
354
- for j in range(2):
355
- self.pOrtho.append(
356
- pg.ViewBox(lockAspect=True, name=f"plotOrtho{j}",
357
- border=[100, 100, 100], invertY=True, enableMouse=False))
358
- self.pOrtho[j].setMenuEnabled(False)
359
-
360
- self.imgOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self))
361
- self.imgOrtho[j].autoDownsample = False
362
-
363
- self.layerOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self))
364
- self.layerOrtho[j].setLevels([0., 255.])
365
-
366
- #self.pOrtho[j].scene().contextMenuItem = self.pOrtho[j]
367
- self.pOrtho[j].addItem(self.imgOrtho[j])
368
- self.pOrtho[j].addItem(self.layerOrtho[j])
369
- self.pOrtho[j].addItem(self.vLineOrtho[j], ignoreBounds=False)
370
- self.pOrtho[j].addItem(self.hLineOrtho[j], ignoreBounds=False)
371
-
372
- self.pOrtho[0].linkView(self.pOrtho[0].YAxis, self.p0)
373
- self.pOrtho[1].linkView(self.pOrtho[1].XAxis, self.p0)
374
-
375
- def add_orthoviews(self):
376
- self.yortho = self.Ly // 2
377
- self.xortho = self.Lx // 2
378
- if self.NZ > 1:
379
- self.update_ortho()
380
-
381
- self.win.addItem(self.pOrtho[0], 0, 1, rowspan=1, colspan=1)
382
- self.win.addItem(self.pOrtho[1], 1, 0, rowspan=1, colspan=1)
383
-
384
- qGraphicsGridLayout = self.win.ci.layout
385
- qGraphicsGridLayout.setColumnStretchFactor(0, 2)
386
- qGraphicsGridLayout.setColumnStretchFactor(1, 1)
387
- qGraphicsGridLayout.setRowStretchFactor(0, 2)
388
- qGraphicsGridLayout.setRowStretchFactor(1, 1)
389
-
390
- self.pOrtho[0].setYRange(0, self.Lx)
391
- self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
392
- self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
393
- self.pOrtho[1].setXRange(0, self.Ly)
394
-
395
- self.p0.addItem(self.vLine, ignoreBounds=False)
396
- self.p0.addItem(self.hLine, ignoreBounds=False)
397
- self.p0.setYRange(0, self.Lx)
398
- self.p0.setXRange(0, self.Ly)
399
-
400
- self.win.show()
401
- self.show()
402
-
403
- def remove_orthoviews(self):
404
- self.win.removeItem(self.pOrtho[0])
405
- self.win.removeItem(self.pOrtho[1])
406
- self.p0.removeItem(self.vLine)
407
- self.p0.removeItem(self.hLine)
408
- self.win.show()
409
- self.show()
410
-
411
- def update_crosshairs(self):
412
- self.yortho = min(self.Ly - 1, max(0, int(self.yortho)))
413
- self.xortho = min(self.Lx - 1, max(0, int(self.xortho)))
414
- self.vLine.setPos(self.xortho)
415
- self.hLine.setPos(self.yortho)
416
- self.vLineOrtho[1].setPos(self.xortho)
417
- self.hLineOrtho[1].setPos(self.zc)
418
- self.vLineOrtho[0].setPos(self.zc)
419
- self.hLineOrtho[0].setPos(self.yortho)
420
-
421
- def update_ortho(self):
422
- if self.NZ > 1 and self.orthobtn.isChecked():
423
- dzcurrent = self.dz
424
- self.dz = min(100, max(3, int(self.dzedit.text())))
425
- self.zaspect = max(0.01, min(100., float(self.zaspectedit.text())))
426
- self.dzedit.setText(str(self.dz))
427
- self.zaspectedit.setText(str(self.zaspect))
428
- if self.dz != dzcurrent:
429
- self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
430
- self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
431
- dztot = min(self.NZ, self.dz * 2)
432
- y = self.yortho
433
- x = self.xortho
434
- z = self.currentZ
435
- if dztot == self.NZ:
436
- zmin, zmax = 0, self.NZ
437
- else:
438
- if z - self.dz < 0:
439
- zmin = 0
440
- zmax = zmin + self.dz * 2
441
- elif z + self.dz >= self.NZ:
442
- zmax = self.NZ
443
- zmin = zmax - self.dz * 2
444
- else:
445
- zmin, zmax = z - self.dz, z + self.dz
446
- self.zc = z - zmin
447
- self.update_crosshairs()
448
- if self.view == 0 or self.view == 4:
449
- for j in range(2):
450
- if j == 0:
451
- if self.view == 0:
452
- image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2).copy()
453
- else:
454
- image = self.stack_filtered[zmin:zmax, :,
455
- x].transpose(1, 0, 2).copy()
456
- else:
457
- image = self.stack[
458
- zmin:zmax,
459
- y, :].copy() if self.view == 0 else self.stack_filtered[zmin:zmax,
460
- y, :].copy()
461
- if self.nchan == 1:
462
- # show single channel
463
- image = image[..., 0]
464
- if self.color == 0:
465
- self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
466
- if self.nchan > 1:
467
- levels = np.array([
468
- self.saturation[0][self.currentZ],
469
- self.saturation[1][self.currentZ],
470
- self.saturation[2][self.currentZ]
471
- ])
472
- self.imgOrtho[j].setLevels(levels)
473
- else:
474
- self.imgOrtho[j].setLevels(
475
- self.saturation[0][self.currentZ])
476
- elif self.color > 0 and self.color < 4:
477
- if self.nchan > 1:
478
- image = image[..., self.color - 1]
479
- self.imgOrtho[j].setImage(image, autoLevels=False,
480
- lut=self.cmap[self.color])
481
- if self.nchan > 1:
482
- self.imgOrtho[j].setLevels(
483
- self.saturation[self.color - 1][self.currentZ])
484
- else:
485
- self.imgOrtho[j].setLevels(
486
- self.saturation[0][self.currentZ])
487
- elif self.color == 4:
488
- if image.ndim > 2:
489
- image = image.astype("float32").mean(axis=2).astype("uint8")
490
- self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
491
- self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
492
- elif self.color == 5:
493
- if image.ndim > 2:
494
- image = image.astype("float32").mean(axis=2).astype("uint8")
495
- self.imgOrtho[j].setImage(image, autoLevels=False,
496
- lut=self.cmap[0])
497
- self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
498
- self.pOrtho[0].setAspectLocked(lock=True, ratio=self.zaspect)
499
- self.pOrtho[1].setAspectLocked(lock=True, ratio=1. / self.zaspect)
500
-
501
- else:
502
- image = np.zeros((10, 10), "uint8")
503
- self.imgOrtho[0].setImage(image, autoLevels=False, lut=None)
504
- self.imgOrtho[0].setLevels([0.0, 255.0])
505
- self.imgOrtho[1].setImage(image, autoLevels=False, lut=None)
506
- self.imgOrtho[1].setLevels([0.0, 255.0])
507
-
508
- zrange = zmax - zmin
509
- self.layer_ortho = [
510
- np.zeros((self.Ly, zrange, 4), "uint8"),
511
- np.zeros((zrange, self.Lx, 4), "uint8")
512
- ]
513
- if self.masksOn:
514
- for j in range(2):
515
- if j == 0:
516
- cp = self.cellpix[zmin:zmax, :, x].T
517
- else:
518
- cp = self.cellpix[zmin:zmax, y]
519
- self.layer_ortho[j][..., :3] = self.cellcolors[cp, :]
520
- self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype("uint8")
521
- if self.selected > 0:
522
- self.layer_ortho[j][cp == self.selected] = np.array(
523
- [255, 255, 255, self.opacity])
524
-
525
- if self.outlinesOn:
526
- for j in range(2):
527
- if j == 0:
528
- op = self.outpix[zmin:zmax, :, x].T
529
- else:
530
- op = self.outpix[zmin:zmax, y]
531
- self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype("uint8")
532
-
533
- for j in range(2):
534
- self.layerOrtho[j].setImage(self.layer_ortho[j])
535
- self.win.show()
536
- self.show()
537
-
538
- def toggle_ortho(self):
539
- if self.orthobtn.isChecked():
540
- self.add_orthoviews()
541
- else:
542
- self.remove_orthoviews()
543
-
544
- def plot_clicked(self, event):
545
- if event.button()==QtCore.Qt.LeftButton \
546
- and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\
547
- and not self.removing_region:
548
- if event.double():
549
- try:
550
- self.p0.setYRange(0, self.Ly + self.pr)
551
- except:
552
- self.p0.setYRange(0, self.Ly)
553
- self.p0.setXRange(0, self.Lx)
554
- elif self.loaded and not self.in_stroke:
555
- if self.orthobtn.isChecked():
556
- items = self.win.scene().items(event.scenePos())
557
- for x in items:
558
- if x == self.p0:
559
- pos = self.p0.mapSceneToView(event.scenePos())
560
- x = int(pos.x())
561
- y = int(pos.y())
562
- if y >= 0 and y < self.Ly and x >= 0 and x < self.Lx:
563
- self.yortho = y
564
- self.xortho = x
565
- self.update_ortho()
566
-
567
- def update_plot(self):
568
- super().update_plot()
569
- if self.NZ > 1 and self.orthobtn.isChecked():
570
- self.update_ortho()
571
- self.win.show()
572
- self.show()
573
-
574
- def keyPressEvent(self, event):
575
- if self.loaded:
576
- if not (event.modifiers() &
577
- (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier |
578
- QtCore.Qt.AltModifier) or self.in_stroke):
579
- updated = False
580
- if len(self.current_point_set) > 0:
581
- if event.key() == QtCore.Qt.Key_Return:
582
- self.add_set()
583
- if self.NZ > 1:
584
- if event.key() == QtCore.Qt.Key_Left:
585
- self.currentZ = max(0, self.currentZ - 1)
586
- self.scroll.setValue(self.currentZ)
587
- updated = True
588
- elif event.key() == QtCore.Qt.Key_Right:
589
- self.currentZ = min(self.NZ - 1, self.currentZ + 1)
590
- self.scroll.setValue(self.currentZ)
591
- updated = True
592
- else:
593
- nviews = self.ViewDropDown.count() - 1
594
- nviews += int(
595
- self.ViewDropDown.model().item(self.ViewDropDown.count() -
596
- 1).isEnabled())
597
- if event.key() == QtCore.Qt.Key_X:
598
- self.MCheckBox.toggle()
599
- if event.key() == QtCore.Qt.Key_Z:
600
- self.OCheckBox.toggle()
601
- if event.key() == QtCore.Qt.Key_Left or event.key(
602
- ) == QtCore.Qt.Key_A:
603
- self.currentZ = max(0, self.currentZ - 1)
604
- self.scroll.setValue(self.currentZ)
605
- updated = True
606
- elif event.key() == QtCore.Qt.Key_Right or event.key(
607
- ) == QtCore.Qt.Key_D:
608
- self.currentZ = min(self.NZ - 1, self.currentZ + 1)
609
- self.scroll.setValue(self.currentZ)
610
- updated = True
611
- elif event.key() == QtCore.Qt.Key_PageDown:
612
- self.view = (self.view + 1) % (nviews)
613
- self.ViewDropDown.setCurrentIndex(self.view)
614
- elif event.key() == QtCore.Qt.Key_PageUp:
615
- self.view = (self.view - 1) % (nviews)
616
- self.ViewDropDown.setCurrentIndex(self.view)
617
-
618
- # can change background or stroke size if cell not finished
619
- if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W:
620
- self.color = (self.color - 1) % (6)
621
- self.RGBDropDown.setCurrentIndex(self.color)
622
- elif event.key() == QtCore.Qt.Key_Down or event.key(
623
- ) == QtCore.Qt.Key_S:
624
- self.color = (self.color + 1) % (6)
625
- self.RGBDropDown.setCurrentIndex(self.color)
626
- elif event.key() == QtCore.Qt.Key_R:
627
- if self.color != 1:
628
- self.color = 1
629
- else:
630
- self.color = 0
631
- self.RGBDropDown.setCurrentIndex(self.color)
632
- elif event.key() == QtCore.Qt.Key_G:
633
- if self.color != 2:
634
- self.color = 2
635
- else:
636
- self.color = 0
637
- self.RGBDropDown.setCurrentIndex(self.color)
638
- elif event.key() == QtCore.Qt.Key_B:
639
- if self.color != 3:
640
- self.color = 3
641
- else:
642
- self.color = 0
643
- self.RGBDropDown.setCurrentIndex(self.color)
644
- elif (event.key() == QtCore.Qt.Key_Comma or
645
- event.key() == QtCore.Qt.Key_Period):
646
- count = self.BrushChoose.count()
647
- gci = self.BrushChoose.currentIndex()
648
- if event.key() == QtCore.Qt.Key_Comma:
649
- gci = max(0, gci - 1)
650
- else:
651
- gci = min(count - 1, gci + 1)
652
- self.BrushChoose.setCurrentIndex(gci)
653
- self.brush_choose()
654
- if not updated:
655
- self.update_plot()
656
- if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
657
- self.p0.keyPressEvent(event)
658
-
659
- def update_ztext(self):
660
- zpos = self.currentZ
661
- try:
662
- zpos = int(self.zpos.text())
663
- except:
664
- print("ERROR: zposition is not a number")
665
- self.currentZ = max(0, min(self.NZ - 1, zpos))
666
- self.zpos.setText(str(self.currentZ))
667
- self.scroll.setValue(self.currentZ)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/gui/guihelpwindowtext.html DELETED
@@ -1,143 +0,0 @@
1
- <qt>
2
- <p class="has-line-data" data-line-start="5" data-line-end="6">
3
- <b>Main GUI mouse controls:</b>
4
- </p>
5
- <ul>
6
- <li class="has-line-data" data-line-start="7" data-line-end="8">Pan = left-click + drag</li>
7
- <li class="has-line-data" data-line-start="8" data-line-end="9">Zoom = scroll wheel (or +/= and - buttons)</li>
8
- <li class="has-line-data" data-line-start="9" data-line-end="10">Full view = double left-click</li>
9
- <li class="has-line-data" data-line-start="10" data-line-end="11">Select mask = left-click on mask</li>
10
- <li class="has-line-data" data-line-start="11" data-line-end="12">Delete mask = Ctrl (or COMMAND on Mac) +
11
- left-click
12
- </li>
13
- <li class="has-line-data" data-line-start="11" data-line-end="12">Merge masks = Alt + left-click (will merge
14
- last two)
15
- </li>
16
- <li class="has-line-data" data-line-start="12" data-line-end="13">Start draw mask = right-click</li>
17
- <li class="has-line-data" data-line-start="13" data-line-end="15">End draw mask = right-click, or return to
18
- circle at beginning
19
- </li>
20
- </ul>
21
- <p class="has-line-data" data-line-start="15" data-line-end="16">Overlaps in masks are NOT allowed. If you
22
- draw a mask on top of another mask, it is cropped so that it doesn’t overlap with the old mask. Masks in 2D
23
- should be single strokes (single stroke is checked). If you want to draw masks in 3D (experimental), then
24
- you can turn this option off and draw a stroke on each plane with the cell and then press ENTER. 3D
25
- labelling will fill in planes that you have not labelled so that you do not have to as densely label.
26
- </p>
27
- <p class="has-line-data" data-line-start="17" data-line-end="18"> <b>!NOTE!:</b> The GUI automatically saves after
28
- you draw a mask in 2D but NOT after 3D mask drawing and NOT after segmentation. Save in the file menu or
29
- with Ctrl+S. The output file is in the same folder as the loaded image with <code>_seg.npy</code> appended.
30
- </p>
31
-
32
- <p class="has-line-data" data-line-start="19" data-line-end="20"> <b>Bulk Mask Deletion</b>
33
- Clicking the 'delete multiple' button will allow you to select and delete multiple masks at once.
34
- Masks can be deselected by clicking on them again. Once you have selected all the masks you want to delete,
35
- click the 'done' button to delete them.
36
- <br>
37
- <br>
38
- Alternatively, you can create a rectangular region to delete a regions of masks by clicking the
39
- 'delete multiple' button, and then moving and/or resizing the region to select the masks you want to delete.
40
- Once you have selected the masks you want to delete, click the 'done' button to delete them.
41
- <br>
42
- <br>
43
- At any point in the process, you can click the 'cancel' button to cancel the bulk deletion.
44
- </p>
45
- <hr>
46
- <table class="table table-striped table-bordered">
47
- <br>
48
- <br>
49
- FYI there are tooltips throughout the GUI (hover over text to see)
50
- <br>
51
- <thead>
52
- <tr>
53
- <th>Keyboard shortcuts</th>
54
- <th>Description</th>
55
- </tr>
56
- </thead>
57
- <tbody>
58
- <tr>
59
- <td>=/+ button // - button</td>
60
- <td>zoom in // zoom out</td>
61
- </tr>
62
- <tr>
63
- <td>CTRL+Z</td>
64
- <td>undo previously drawn mask/stroke</td>
65
- </tr>
66
- <tr>
67
- <td>CTRL+Y</td>
68
- <td>undo remove mask</td>
69
- </tr>
70
- <tr>
71
- <td>CTRL+0</td>
72
- <td>clear all masks</td>
73
- </tr>
74
- <tr>
75
- <td>CTRL+L</td>
76
- <td>load image (can alternatively drag and drop image)</td>
77
- </tr>
78
- <tr>
79
- <td>CTRL+S</td>
80
- <td>SAVE MASKS IN IMAGE to <code>_seg.npy</code> file</td>
81
- </tr>
82
- <tr>
83
- <td>CTRL+T</td>
84
- <td>train model using _seg.npy files in folder
85
- </tr>
86
- <tr>
87
- <td>CTRL+P</td>
88
- <td>load <code>_seg.npy</code> file (note: it will load automatically with image if it exists)</td>
89
- </tr>
90
- <tr>
91
- <td>CTRL+M</td>
92
- <td>load masks file (must be same size as image with 0 for NO mask, and 1,2,3… for masks)</td>
93
- </tr>
94
- <tr>
95
- <td>CTRL+N</td>
96
- <td>save masks as PNG</td>
97
- </tr>
98
- <tr>
99
- <td>CTRL+R</td>
100
- <td>save ROIs to native ImageJ ROI format</td>
101
- </tr>
102
- <tr>
103
- <td>CTRL+F</td>
104
- <td>save flows to image file</td>
105
- </tr>
106
- <tr>
107
- <td>A/D or LEFT/RIGHT</td>
108
- <td>cycle through images in current directory</td>
109
- </tr>
110
- <tr>
111
- <td>W/S or UP/DOWN</td>
112
- <td>change color (RGB/gray/red/green/blue)</td>
113
- </tr>
114
- <tr>
115
- <td>R / G / B</td>
116
- <td>toggle between RGB and Red or Green or Blue</td>
117
- </tr>
118
- <tr>
119
- <td>PAGE-UP / PAGE-DOWN</td>
120
- <td>change to flows and cell prob views (if segmentation computed)</td>
121
- </tr>
122
- <tr>
123
- <td>X</td>
124
- <td>turn masks ON or OFF</td>
125
- </tr>
126
- <tr>
127
- <td>Z</td>
128
- <td>toggle outlines ON or OFF</td>
129
- </tr>
130
- <tr>
131
- <td>, / .</td>
132
- <td>increase / decrease brush size for drawing masks</td>
133
- </tr>
134
- </tbody>
135
- </table>
136
- <p class="has-line-data" data-line-start="36" data-line-end="37"><strong>Segmentation options
137
- (2D only) </strong></p>
138
- <p class="has-line-data" data-line-start="38" data-line-end="39">use GPU: if you have specially
139
- installed the cuda version of torch, then you can activate this. Due to the size of the
140
- transformer network, it will greatly speed up the processing time.</p>
141
- <p class="has-line-data" data-line-start="40" data-line-end="41">There are no channel options
142
- in v4.0.1+ since all 3 channels are used for segmentation. </p>
143
- </qt>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/gui/guiparts.py DELETED
@@ -1,793 +0,0 @@
1
- """
2
- Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
- """
4
- from qtpy import QtGui, QtCore
5
- from qtpy.QtGui import QPixmap, QDoubleValidator
6
- from qtpy.QtWidgets import QWidget, QDialog, QGridLayout, QPushButton, QLabel, QLineEdit, QDialogButtonBox, QComboBox, QCheckBox, QVBoxLayout
7
- import pyqtgraph as pg
8
- import numpy as np
9
- import pathlib, os
10
-
11
-
12
- def stylesheet():
13
- return """
14
- QToolTip {
15
- background-color: black;
16
- color: white;
17
- border: black solid 1px
18
- }
19
- QComboBox {color: white;
20
- background-color: rgb(40,40,40);}
21
- QComboBox::item:enabled { color: white;
22
- background-color: rgb(40,40,40);
23
- selection-color: white;
24
- selection-background-color: rgb(50,100,50);}
25
- QComboBox::item:!enabled {
26
- background-color: rgb(40,40,40);
27
- color: rgb(100,100,100);
28
- }
29
- QScrollArea > QWidget > QWidget
30
- {
31
- background: transparent;
32
- border: none;
33
- margin: 0px 0px 0px 0px;
34
- }
35
-
36
- QGroupBox
37
- { border: 1px solid white; color: rgb(255,255,255);
38
- border-radius: 6px;
39
- margin-top: 8px;
40
- padding: 0px 0px;}
41
-
42
- QPushButton:pressed {Text-align: center;
43
- background-color: rgb(150,50,150);
44
- border-color: white;
45
- color:white;}
46
- QToolTip {
47
- background-color: black;
48
- color: white;
49
- border: black solid 1px
50
- }
51
- QPushButton:!pressed {Text-align: center;
52
- background-color: rgb(50,50,50);
53
- border-color: white;
54
- color:white;}
55
- QToolTip {
56
- background-color: black;
57
- color: white;
58
- border: black solid 1px
59
- }
60
- QPushButton:disabled {Text-align: center;
61
- background-color: rgb(30,30,30);
62
- border-color: white;
63
- color:rgb(80,80,80);}
64
- QToolTip {
65
- background-color: black;
66
- color: white;
67
- border: black solid 1px
68
- }
69
-
70
- """
71
-
72
-
73
- class DarkPalette(QtGui.QPalette):
74
- """Class that inherits from pyqtgraph.QtGui.QPalette and renders dark colours for the application.
75
- (from pykilosort/kilosort4)
76
- """
77
-
78
- def __init__(self):
79
- QtGui.QPalette.__init__(self)
80
- self.setup()
81
-
82
- def setup(self):
83
- self.setColor(QtGui.QPalette.Window, QtGui.QColor(40, 40, 40))
84
- self.setColor(QtGui.QPalette.WindowText, QtGui.QColor(255, 255, 255))
85
- self.setColor(QtGui.QPalette.Base, QtGui.QColor(34, 27, 24))
86
- self.setColor(QtGui.QPalette.AlternateBase, QtGui.QColor(53, 50, 47))
87
- self.setColor(QtGui.QPalette.ToolTipBase, QtGui.QColor(255, 255, 255))
88
- self.setColor(QtGui.QPalette.ToolTipText, QtGui.QColor(255, 255, 255))
89
- self.setColor(QtGui.QPalette.Text, QtGui.QColor(255, 255, 255))
90
- self.setColor(QtGui.QPalette.Button, QtGui.QColor(53, 50, 47))
91
- self.setColor(QtGui.QPalette.ButtonText, QtGui.QColor(255, 255, 255))
92
- self.setColor(QtGui.QPalette.BrightText, QtGui.QColor(255, 0, 0))
93
- self.setColor(QtGui.QPalette.Link, QtGui.QColor(42, 130, 218))
94
- self.setColor(QtGui.QPalette.Highlight, QtGui.QColor(42, 130, 218))
95
- self.setColor(QtGui.QPalette.HighlightedText, QtGui.QColor(0, 0, 0))
96
- self.setColor(QtGui.QPalette.Disabled, QtGui.QPalette.Text,
97
- QtGui.QColor(128, 128, 128))
98
- self.setColor(
99
- QtGui.QPalette.Disabled,
100
- QtGui.QPalette.ButtonText,
101
- QtGui.QColor(128, 128, 128),
102
- )
103
- self.setColor(
104
- QtGui.QPalette.Disabled,
105
- QtGui.QPalette.WindowText,
106
- QtGui.QColor(128, 128, 128),
107
- )
108
-
109
-
110
- # def create_channel_choose():
111
- # # choose channel
112
- # ChannelChoose = [QComboBox(), QComboBox()]
113
- # ChannelLabels = []
114
- # ChannelChoose[0].addItems(["gray", "red", "green", "blue"])
115
- # ChannelChoose[1].addItems(["none", "red", "green", "blue"])
116
- # cstr = ["chan to segment:", "chan2 (optional): "]
117
- # for i in range(2):
118
- # ChannelLabels.append(QLabel(cstr[i]))
119
- # if i == 0:
120
- # ChannelLabels[i].setToolTip(
121
- # "this is the channel in which the cytoplasm or nuclei exist \
122
- # that you want to segment")
123
- # ChannelChoose[i].setToolTip(
124
- # "this is the channel in which the cytoplasm or nuclei exist \
125
- # that you want to segment")
126
- # else:
127
- # ChannelLabels[i].setToolTip(
128
- # "if <em>cytoplasm</em> model is chosen, and you also have a \
129
- # nuclear channel, then choose the nuclear channel for this option")
130
- # ChannelChoose[i].setToolTip(
131
- # "if <em>cytoplasm</em> model is chosen, and you also have a \
132
- # nuclear channel, then choose the nuclear channel for this option")
133
-
134
- # return ChannelChoose, ChannelLabels
135
-
136
-
137
- class ModelButton(QPushButton):
138
-
139
- def __init__(self, parent, model_name, text):
140
- super().__init__()
141
- self.setEnabled(False)
142
- self.setText(text)
143
- self.setFont(parent.boldfont)
144
- self.clicked.connect(lambda: self.press(parent))
145
- self.model_name = "cpsam"
146
-
147
- def press(self, parent):
148
- parent.compute_segmentation(model_name="cpsam")
149
-
150
-
151
- class FilterButton(QPushButton):
152
-
153
- def __init__(self, parent, text):
154
- super().__init__()
155
- self.setEnabled(False)
156
- self.model_type = text
157
- self.setText(text)
158
- self.setFont(parent.medfont)
159
- self.clicked.connect(lambda: self.press(parent))
160
-
161
- def press(self, parent):
162
- if self.model_type == "filter":
163
- parent.restore = "filter"
164
- normalize_params = parent.get_normalize_params()
165
- if (normalize_params["sharpen_radius"] == 0 and
166
- normalize_params["smooth_radius"] == 0 and
167
- normalize_params["tile_norm_blocksize"] == 0):
168
- print(
169
- "GUI_ERROR: no filtering settings on (use custom filter settings)")
170
- parent.restore = None
171
- return
172
- parent.restore = self.model_type
173
- parent.compute_saturation()
174
- # elif self.model_type != "none":
175
- # parent.compute_denoise_model(model_type=self.model_type)
176
- else:
177
- parent.clear_restore()
178
- # parent.set_restore_button()
179
-
180
-
181
- class ObservableVariable(QtCore.QObject):
182
- valueChanged = QtCore.Signal(object)
183
-
184
- def __init__(self, initial=None):
185
- super().__init__()
186
- self._value = initial
187
-
188
- def set(self, new_value):
189
- """ Use this method to get emit the value changing and update the ROI count"""
190
- if new_value != self._value:
191
- self._value = new_value
192
- self.valueChanged.emit(new_value)
193
-
194
- def get(self):
195
- return self._value
196
-
197
- def __call__(self):
198
- return self._value
199
-
200
- def reset(self):
201
- self.set(0)
202
-
203
- def __iadd__(self, amount):
204
- if not isinstance(amount, (int, float)):
205
- raise TypeError("Value must be numeric.")
206
- self.set(self._value + amount)
207
- return self
208
-
209
- def __radd__(self, other):
210
- return other + self._value
211
-
212
- def __add__(self, other):
213
- return other + self._value
214
-
215
- def __isub__(self, amount):
216
- if not isinstance(amount, (int, float)):
217
- raise TypeError("Value must be numeric.")
218
- self.set(self._value - amount)
219
- return self
220
-
221
- def __str__(self):
222
- return str(self._value)
223
-
224
- def __lt__(self, x):
225
- return self._value < x
226
-
227
- def __gt__(self, x):
228
- return self._value > x
229
-
230
- def __eq__(self, x):
231
- return self._value == x
232
-
233
-
234
- class NormalizationSettings(QWidget):
235
- # TODO
236
- pass
237
-
238
-
239
- class SegmentationSettings(QWidget):
240
- """ Container for gui settings. Validation is done automatically so any attributes can
241
- be acessed without concern.
242
- """
243
- def __init__(self, font):
244
- super().__init__()
245
-
246
- # Put everything in a grid layout:
247
- grid_layout = QGridLayout()
248
- widget_container = QWidget()
249
- widget_container.setLayout(grid_layout)
250
- row = 0
251
-
252
- ########################### Diameter ###########################
253
- # TODO: Validate inputs
254
- diam_qlabel = QLabel("diameter:")
255
- diam_qlabel.setToolTip("diameter of cells in pixels. If not 30, image will be resized to this")
256
- diam_qlabel.setFont(font)
257
- grid_layout.addWidget(diam_qlabel, row, 0, 1, 2)
258
- self.diameter_box = QLineEdit()
259
- self.diameter_box.setToolTip("diameter of cells in pixels. If not blank, image will be resized relative to 30 pixel cell diameters")
260
- self.diameter_box.setFont(font)
261
- self.diameter_box.setFixedWidth(40)
262
- self.diameter_box.setText(' ')
263
- grid_layout.addWidget(self.diameter_box, row, 2, 1, 2)
264
-
265
- row += 1
266
-
267
- ########################### Flow threshold ###########################
268
- # TODO: Validate inputs
269
- flow_threshold_qlabel = QLabel("flow\nthreshold:")
270
- flow_threshold_qlabel.setToolTip("threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run")
271
- flow_threshold_qlabel.setFont(font)
272
- grid_layout.addWidget(flow_threshold_qlabel, row, 0, 1, 2)
273
- self.flow_threshold_box = QLineEdit()
274
- self.flow_threshold_box.setText("0.4")
275
- self.flow_threshold_box.setFixedWidth(40)
276
- self.flow_threshold_box.setFont(font)
277
- grid_layout.addWidget(self.flow_threshold_box, row, 2, 1, 2)
278
- self.flow_threshold_box.setToolTip("threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run")
279
-
280
- ########################### Cellprob threshold ###########################
281
- # TODO: Validate inputs
282
- cellprob_qlabel = QLabel("cellprob\nthreshold:")
283
- cellprob_qlabel.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run")
284
- cellprob_qlabel.setFont(font)
285
- grid_layout.addWidget(cellprob_qlabel, row, 4, 1, 2)
286
- self.cellprob_threshold_box = QLineEdit()
287
- self.cellprob_threshold_box.setText("0.0")
288
- self.cellprob_threshold_box.setFixedWidth(40)
289
- self.cellprob_threshold_box.setFont(font)
290
- self.cellprob_threshold_box.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run")
291
- grid_layout.addWidget(self.cellprob_threshold_box, row, 6, 1, 2)
292
-
293
- row += 1
294
-
295
- ########################### Norm percentiles ###########################
296
- norm_percentiles_qlabel = QLabel("norm percentiles:")
297
- norm_percentiles_qlabel.setToolTip("sets normalization percentiles for segmentation and denoising\n(pixels at lower percentile set to 0.0 and at upper set to 1.0 for network)")
298
- norm_percentiles_qlabel.setFont(font)
299
- grid_layout.addWidget(norm_percentiles_qlabel, row, 0, 1, 8)
300
-
301
- row += 1
302
- validator = QDoubleValidator(0.0, 100.0, 2)
303
- validator.setNotation(QDoubleValidator.StandardNotation)
304
-
305
- low_norm_qlabel = QLabel('lower:')
306
- low_norm_qlabel.setToolTip("pixels at this percentile set to 0 (default 1.0)")
307
- low_norm_qlabel.setFont(font)
308
- grid_layout.addWidget(low_norm_qlabel, row, 0, 1, 2)
309
- self.norm_percentile_low_box = QLineEdit()
310
- self.norm_percentile_low_box.setText("1.0")
311
- self.norm_percentile_low_box.setFont(font)
312
- self.norm_percentile_low_box.setFixedWidth(40)
313
- self.norm_percentile_low_box.setToolTip("pixels at this percentile set to 0 (default 1.0)")
314
- self.norm_percentile_low_box.setValidator(validator)
315
- self.norm_percentile_low_box.editingFinished.connect(self.validate_normalization_range)
316
- grid_layout.addWidget(self.norm_percentile_low_box, row, 2, 1, 1)
317
-
318
- high_norm_qlabel = QLabel('upper:')
319
- high_norm_qlabel.setToolTip("pixels at this percentile set to 1 (default 99.0)")
320
- high_norm_qlabel.setFont(font)
321
- grid_layout.addWidget(high_norm_qlabel, row, 4, 1, 2)
322
- self.norm_percentile_high_box = QLineEdit()
323
- self.norm_percentile_high_box.setText("99.0")
324
- self.norm_percentile_high_box.setFont(font)
325
- self.norm_percentile_high_box.setFixedWidth(40)
326
- self.norm_percentile_high_box.setToolTip("pixels at this percentile set to 1 (default 99.0)")
327
- self.norm_percentile_high_box.setValidator(validator)
328
- self.norm_percentile_high_box.editingFinished.connect(self.validate_normalization_range)
329
- grid_layout.addWidget(self.norm_percentile_high_box, row, 6, 1, 2)
330
-
331
- row += 1
332
-
333
- ########################### niter ###########################
334
- # TODO: change this to follow the same default logic as 'diameter' above
335
- # TODO: input validation
336
- niter_qlabel = QLabel("niter dynamics:")
337
- niter_qlabel.setFont(font)
338
- niter_qlabel.setToolTip("number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria")
339
- grid_layout.addWidget(niter_qlabel, row, 0, 1, 4)
340
- self.niter_box = QLineEdit()
341
- self.niter_box.setText("0")
342
- self.niter_box.setFixedWidth(40)
343
- self.niter_box.setFont(font)
344
- self.niter_box.setToolTip("number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria")
345
- grid_layout.addWidget(self.niter_box, row, 4, 1, 2)
346
-
347
- self.setLayout(grid_layout)
348
-
349
- def validate_normalization_range(self):
350
- low_text = self.norm_percentile_low_box.text()
351
- high_text = self.norm_percentile_high_box.text()
352
-
353
- if not low_text or low_text.isspace():
354
- self.norm_percentile_low_box.setText('1.0')
355
- low_text = '1.0'
356
- elif not high_text or high_text.isspace():
357
- self.norm_percentile_high_box.setText('1.0')
358
- high_text = '99.0'
359
-
360
- low = float(low_text)
361
- high = float(high_text)
362
-
363
- if low >= high:
364
- # Invalid: show error and mark fields
365
- self.norm_percentile_low_box.setStyleSheet("border: 1px solid red;")
366
- self.norm_percentile_high_box.setStyleSheet("border: 1px solid red;")
367
- else:
368
- # Valid: clear style
369
- self.norm_percentile_low_box.setStyleSheet("")
370
- self.norm_percentile_high_box.setStyleSheet("")
371
-
372
- @property
373
- def low_percentile(self):
374
- """ Also validate the low input by returning 1.0 if text doesn't work """
375
- low_text = self.norm_percentile_low_box.text()
376
- if not low_text or low_text.isspace():
377
- self.norm_percentile_low_box.setText('1.0')
378
- low_text = '1.0'
379
- return float(self.norm_percentile_low_box.text())
380
-
381
- @property
382
- def high_percentile(self):
383
- """ Also validate the high input by returning 99.0 if text doesn't work """
384
- high_text = self.norm_percentile_high_box.text()
385
- if not high_text or high_text.isspace():
386
- self.norm_percentile_high_box.setText('99.0')
387
- high_text = '99.0'
388
- return float(self.norm_percentile_high_box.text())
389
-
390
- @property
391
- def diameter(self):
392
- """ Get the diameter from the diameter box, if box isn't a number return None"""
393
- try:
394
- d = float(self.diameter_box.text())
395
- except ValueError:
396
- d = None
397
- return d
398
-
399
- @property
400
- def flow_threshold(self):
401
- return float(self.flow_threshold_box.text())
402
-
403
- @property
404
- def cellprob_threshold(self):
405
- return float(self.cellprob_threshold_box.text())
406
-
407
- @property
408
- def niter(self):
409
- num = int(self.niter_box.text())
410
- if num < 1:
411
- self.niter_box.setText('200')
412
- return 200
413
- else:
414
- return num
415
-
416
-
417
-
418
- class TrainWindow(QDialog):
419
-
420
- def __init__(self, parent, model_strings):
421
- super().__init__(parent)
422
- self.setGeometry(100, 100, 900, 550)
423
- self.setWindowTitle("train settings")
424
- self.win = QWidget(self)
425
- self.l0 = QGridLayout()
426
- self.win.setLayout(self.l0)
427
-
428
- yoff = 0
429
- qlabel = QLabel("train model w/ images + _seg.npy in current folder >>")
430
- qlabel.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
431
-
432
- qlabel.setAlignment(QtCore.Qt.AlignVCenter)
433
- self.l0.addWidget(qlabel, yoff, 0, 1, 2)
434
-
435
- # choose initial model
436
- yoff += 1
437
- self.ModelChoose = QComboBox()
438
- self.ModelChoose.addItems(model_strings)
439
- self.ModelChoose.setFixedWidth(150)
440
- self.ModelChoose.setCurrentIndex(parent.training_params["model_index"])
441
- self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1)
442
- qlabel = QLabel("initial model: ")
443
- qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
444
- self.l0.addWidget(qlabel, yoff, 0, 1, 1)
445
-
446
- # choose parameters
447
- labels = ["learning_rate", "weight_decay", "n_epochs", "model_name"]
448
- self.edits = []
449
- yoff += 1
450
- for i, label in enumerate(labels):
451
- qlabel = QLabel(label)
452
- qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
453
- self.l0.addWidget(qlabel, i + yoff, 0, 1, 1)
454
- self.edits.append(QLineEdit())
455
- self.edits[-1].setText(str(parent.training_params[label]))
456
- self.edits[-1].setFixedWidth(200)
457
- self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1)
458
-
459
- yoff += len(labels)
460
-
461
- yoff += 1
462
- self.use_norm = QCheckBox(f"use restored/filtered image")
463
- self.use_norm.setChecked(True)
464
-
465
- yoff += 2
466
- qlabel = QLabel(
467
- "(to remove files, click cancel then remove \nfrom folder and reopen train window)"
468
- )
469
- self.l0.addWidget(qlabel, yoff, 0, 2, 4)
470
-
471
- # click button
472
- yoff += 3
473
- QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
474
- self.buttonBox = QDialogButtonBox(QBtn)
475
- self.buttonBox.accepted.connect(lambda: self.accept(parent))
476
- self.buttonBox.rejected.connect(self.reject)
477
- self.l0.addWidget(self.buttonBox, yoff, 0, 1, 4)
478
-
479
- # list files in folder
480
- qlabel = QLabel("filenames")
481
- qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
482
- self.l0.addWidget(qlabel, 0, 4, 1, 1)
483
- qlabel = QLabel("# of masks")
484
- qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
485
- self.l0.addWidget(qlabel, 0, 5, 1, 1)
486
-
487
- for i in range(10):
488
- if i > len(parent.train_files) - 1:
489
- break
490
- elif i == 9 and len(parent.train_files) > 10:
491
- label = "..."
492
- nmasks = "..."
493
- else:
494
- label = os.path.split(parent.train_files[i])[-1]
495
- nmasks = str(parent.train_labels[i].max())
496
- qlabel = QLabel(label)
497
- self.l0.addWidget(qlabel, i + 1, 4, 1, 1)
498
- qlabel = QLabel(nmasks)
499
- qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
500
- self.l0.addWidget(qlabel, i + 1, 5, 1, 1)
501
-
502
- def accept(self, parent):
503
- # set training params
504
- parent.training_params = {
505
- "model_index": self.ModelChoose.currentIndex(),
506
- "learning_rate": float(self.edits[0].text()),
507
- "weight_decay": float(self.edits[1].text()),
508
- "n_epochs": int(self.edits[2].text()),
509
- "model_name": self.edits[3].text(),
510
- #"use_norm": True if self.use_norm.isChecked() else False,
511
- }
512
- self.done(1)
513
-
514
-
515
- class ExampleGUI(QDialog):
516
-
517
- def __init__(self, parent=None):
518
- super(ExampleGUI, self).__init__(parent)
519
- self.setGeometry(100, 100, 1300, 900)
520
- self.setWindowTitle("GUI layout")
521
- self.win = QWidget(self)
522
- layout = QGridLayout()
523
- self.win.setLayout(layout)
524
- guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png")
525
- guip_path = str(guip_path.resolve())
526
- pixmap = QPixmap(guip_path)
527
- label = QLabel(self)
528
- label.setPixmap(pixmap)
529
- pixmap.scaled
530
- layout.addWidget(label, 0, 0, 1, 1)
531
-
532
-
533
- class HelpWindow(QDialog):
534
-
535
- def __init__(self, parent=None):
536
- super(HelpWindow, self).__init__(parent)
537
- self.setGeometry(100, 50, 700, 1000)
538
- self.setWindowTitle("cellpose help")
539
- self.win = QWidget(self)
540
- layout = QGridLayout()
541
- self.win.setLayout(layout)
542
-
543
- text_file = pathlib.Path(__file__).parent.joinpath("guihelpwindowtext.html")
544
- with open(str(text_file.resolve()), "r") as f:
545
- text = f.read()
546
-
547
- label = QLabel(text)
548
- label.setFont(QtGui.QFont("Arial", 8))
549
- label.setWordWrap(True)
550
- layout.addWidget(label, 0, 0, 1, 1)
551
- self.show()
552
-
553
-
554
- class TrainHelpWindow(QDialog):
555
-
556
- def __init__(self, parent=None):
557
- super(TrainHelpWindow, self).__init__(parent)
558
- self.setGeometry(100, 50, 700, 300)
559
- self.setWindowTitle("training instructions")
560
- self.win = QWidget(self)
561
- layout = QGridLayout()
562
- self.win.setLayout(layout)
563
-
564
- text_file = pathlib.Path(__file__).parent.joinpath(
565
- "guitrainhelpwindowtext.html")
566
- with open(str(text_file.resolve()), "r") as f:
567
- text = f.read()
568
-
569
- label = QLabel(text)
570
- label.setFont(QtGui.QFont("Arial", 8))
571
- label.setWordWrap(True)
572
- layout.addWidget(label, 0, 0, 1, 1)
573
- self.show()
574
-
575
-
576
- class ViewBoxNoRightDrag(pg.ViewBox):
577
-
578
- def __init__(self, parent=None, border=None, lockAspect=False, enableMouse=True,
579
- invertY=False, enableMenu=True, name=None, invertX=False):
580
- pg.ViewBox.__init__(self, None, border, lockAspect, enableMouse, invertY,
581
- enableMenu, name, invertX)
582
- self.parent = parent
583
- self.axHistoryPointer = -1
584
-
585
- def keyPressEvent(self, ev):
586
- """
587
- This routine should capture key presses in the current view box.
588
- The following events are implemented:
589
- +/= : moves forward in the zooming stack (if it exists)
590
- - : moves backward in the zooming stack (if it exists)
591
-
592
- """
593
- ev.accept()
594
- if ev.text() == "-":
595
- self.scaleBy([1.1, 1.1])
596
- elif ev.text() in ["+", "="]:
597
- self.scaleBy([0.9, 0.9])
598
- else:
599
- ev.ignore()
600
-
601
-
602
- class ImageDraw(pg.ImageItem):
603
- """
604
- **Bases:** :class:`GraphicsObject <pyqtgraph.GraphicsObject>`
605
- GraphicsObject displaying an image. Optimized for rapid update (ie video display).
606
- This item displays either a 2D numpy array (height, width) or
607
- a 3D array (height, width, RGBa). This array is optionally scaled (see
608
- :func:`setLevels <pyqtgraph.ImageItem.setLevels>`) and/or colored
609
- with a lookup table (see :func:`setLookupTable <pyqtgraph.ImageItem.setLookupTable>`)
610
- before being displayed.
611
- ImageItem is frequently used in conjunction with
612
- :class:`HistogramLUTItem <pyqtgraph.HistogramLUTItem>` or
613
- :class:`HistogramLUTWidget <pyqtgraph.HistogramLUTWidget>` to provide a GUI
614
- for controlling the levels and lookup table used to display the image.
615
- """
616
-
617
- sigImageChanged = QtCore.Signal()
618
-
619
- def __init__(self, image=None, viewbox=None, parent=None, **kargs):
620
- super(ImageDraw, self).__init__()
621
- self.levels = np.array([0, 255])
622
- self.lut = None
623
- self.autoDownsample = False
624
- self.axisOrder = "row-major"
625
- self.removable = False
626
-
627
- self.parent = parent
628
- self.setDrawKernel(kernel_size=self.parent.brush_size)
629
- self.parent.current_stroke = []
630
- self.parent.in_stroke = False
631
-
632
- def mouseClickEvent(self, ev):
633
- if (self.parent.masksOn or
634
- self.parent.outlinesOn) and not self.parent.removing_region:
635
- is_right_click = ev.button() == QtCore.Qt.RightButton
636
- if self.parent.loaded \
637
- and (is_right_click or ev.modifiers() & QtCore.Qt.ShiftModifier and not ev.double())\
638
- and not self.parent.deleting_multiple:
639
- if not self.parent.in_stroke:
640
- ev.accept()
641
- self.create_start(ev.pos())
642
- self.parent.stroke_appended = False
643
- self.parent.in_stroke = True
644
- self.drawAt(ev.pos(), ev)
645
- else:
646
- ev.accept()
647
- self.end_stroke()
648
- self.parent.in_stroke = False
649
- elif not self.parent.in_stroke:
650
- y, x = int(ev.pos().y()), int(ev.pos().x())
651
- if y >= 0 and y < self.parent.Ly and x >= 0 and x < self.parent.Lx:
652
- if ev.button() == QtCore.Qt.LeftButton and not ev.double():
653
- idx = self.parent.cellpix[self.parent.currentZ][y, x]
654
- if idx > 0:
655
- if ev.modifiers() & QtCore.Qt.ControlModifier:
656
- # delete mask selected
657
- self.parent.remove_cell(idx)
658
- elif ev.modifiers() & QtCore.Qt.AltModifier:
659
- self.parent.merge_cells(idx)
660
- elif self.parent.masksOn and not self.parent.deleting_multiple:
661
- self.parent.unselect_cell()
662
- self.parent.select_cell(idx)
663
- elif self.parent.deleting_multiple:
664
- if idx in self.parent.removing_cells_list:
665
- self.parent.unselect_cell_multi(idx)
666
- self.parent.removing_cells_list.remove(idx)
667
- else:
668
- self.parent.select_cell_multi(idx)
669
- self.parent.removing_cells_list.append(idx)
670
-
671
- elif self.parent.masksOn and not self.parent.deleting_multiple:
672
- self.parent.unselect_cell()
673
-
674
- def mouseDragEvent(self, ev):
675
- ev.ignore()
676
- return
677
-
678
- def hoverEvent(self, ev):
679
- if self.parent.in_stroke:
680
- if self.parent.in_stroke:
681
- # continue stroke if not at start
682
- self.drawAt(ev.pos())
683
- if self.is_at_start(ev.pos()):
684
- self.end_stroke()
685
- else:
686
- ev.acceptClicks(QtCore.Qt.RightButton)
687
-
688
- def create_start(self, pos):
689
- self.scatter = pg.ScatterPlotItem([pos.x()], [pos.y()], pxMode=False,
690
- pen=pg.mkPen(color=(255, 0, 0),
691
- width=self.parent.brush_size),
692
- size=max(3 * 2,
693
- self.parent.brush_size * 1.8 * 2),
694
- brush=None)
695
- self.parent.p0.addItem(self.scatter)
696
-
697
- def is_at_start(self, pos):
698
- thresh_out = max(6, self.parent.brush_size * 3)
699
- thresh_in = max(3, self.parent.brush_size * 1.8)
700
- # first check if you ever left the start
701
- if len(self.parent.current_stroke) > 3:
702
- stroke = np.array(self.parent.current_stroke)
703
- dist = (((stroke[1:, 1:] -
704
- stroke[:1, 1:][np.newaxis, :, :])**2).sum(axis=-1))**0.5
705
- dist = dist.flatten()
706
- has_left = (dist > thresh_out).nonzero()[0]
707
- if len(has_left) > 0:
708
- first_left = np.sort(has_left)[0]
709
- has_returned = (dist[max(4, first_left + 1):] < thresh_in).sum()
710
- if has_returned > 0:
711
- return True
712
- else:
713
- return False
714
- else:
715
- return False
716
-
717
- def end_stroke(self):
718
- self.parent.p0.removeItem(self.scatter)
719
- if not self.parent.stroke_appended:
720
- self.parent.strokes.append(self.parent.current_stroke)
721
- self.parent.stroke_appended = True
722
- self.parent.current_stroke = np.array(self.parent.current_stroke)
723
- ioutline = self.parent.current_stroke[:, 3] == 1
724
- self.parent.current_point_set.append(
725
- list(self.parent.current_stroke[ioutline]))
726
- self.parent.current_stroke = []
727
- if self.parent.autosave:
728
- self.parent.add_set()
729
- if len(self.parent.current_point_set) and len(
730
- self.parent.current_point_set[0]) > 0 and self.parent.autosave:
731
- self.parent.add_set()
732
- self.parent.in_stroke = False
733
-
734
- def tabletEvent(self, ev):
735
- pass
736
-
737
- def drawAt(self, pos, ev=None):
738
- mask = self.strokemask
739
- stroke = self.parent.current_stroke
740
- pos = [int(pos.y()), int(pos.x())]
741
- dk = self.drawKernel
742
- kc = self.drawKernelCenter
743
- sx = [0, dk.shape[0]]
744
- sy = [0, dk.shape[1]]
745
- tx = [pos[0] - kc[0], pos[0] - kc[0] + dk.shape[0]]
746
- ty = [pos[1] - kc[1], pos[1] - kc[1] + dk.shape[1]]
747
- kcent = kc.copy()
748
- if tx[0] <= 0:
749
- sx[0] = 0
750
- sx[1] = kc[0] + 1
751
- tx = sx
752
- kcent[0] = 0
753
- if ty[0] <= 0:
754
- sy[0] = 0
755
- sy[1] = kc[1] + 1
756
- ty = sy
757
- kcent[1] = 0
758
- if tx[1] >= self.parent.Ly - 1:
759
- sx[0] = dk.shape[0] - kc[0] - 1
760
- sx[1] = dk.shape[0]
761
- tx[0] = self.parent.Ly - kc[0] - 1
762
- tx[1] = self.parent.Ly
763
- kcent[0] = tx[1] - tx[0] - 1
764
- if ty[1] >= self.parent.Lx - 1:
765
- sy[0] = dk.shape[1] - kc[1] - 1
766
- sy[1] = dk.shape[1]
767
- ty[0] = self.parent.Lx - kc[1] - 1
768
- ty[1] = self.parent.Lx
769
- kcent[1] = ty[1] - ty[0] - 1
770
-
771
- ts = (slice(tx[0], tx[1]), slice(ty[0], ty[1]))
772
- ss = (slice(sx[0], sx[1]), slice(sy[0], sy[1]))
773
- self.image[ts] = mask[ss]
774
-
775
- for ky, y in enumerate(np.arange(ty[0], ty[1], 1, int)):
776
- for kx, x in enumerate(np.arange(tx[0], tx[1], 1, int)):
777
- iscent = np.logical_and(kx == kcent[0], ky == kcent[1])
778
- stroke.append([self.parent.currentZ, x, y, iscent])
779
- self.updateImage()
780
-
781
- def setDrawKernel(self, kernel_size=3):
782
- bs = kernel_size
783
- kernel = np.ones((bs, bs), np.uint8)
784
- self.drawKernel = kernel
785
- self.drawKernelCenter = [
786
- int(np.floor(kernel.shape[0] / 2)),
787
- int(np.floor(kernel.shape[1] / 2))
788
- ]
789
- onmask = 255 * kernel[:, :, np.newaxis]
790
- offmask = np.zeros((bs, bs, 1))
791
- opamask = 100 * kernel[:, :, np.newaxis]
792
- self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1)
793
- self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/gui/guitrainhelpwindowtext.html DELETED
@@ -1,25 +0,0 @@
1
- <qt>
2
- Check out this <a href="https://youtu.be/3Y1VKcxjNy4">video</a> to learn the process.
3
- <ol>
4
- <li>Drag and drop an image from a folder of images with a similar style (like similar cell types).</li>
5
- <li>Run the built-in models on one of the images using the "model zoo" and find the one that works best for your
6
- data. Make sure that if you have a nuclear channel you have selected it for CHAN2.
7
- </li>
8
- <li>Fix the labelling by drawing new ROIs (right-click) and deleting incorrect ones (CTRL+click). The GUI
9
- autosaves any manual changes (but does not autosave after running the model, for that click CTRL+S). The
10
- segmentation is saved in a "_seg.npy" file.
11
- </li>
12
- <li> Go to the "Models" menu in the File bar at the top and click "Train new model..." or use shortcut CTRL+T.
13
- </li>
14
- <li> Choose the pretrained model to start the training from (the model you used in #2), and type in the model
15
- name that you want to use. The other parameters should work well in general for most data types. Then click
16
- OK.
17
- </li>
18
- <li> The model will train (much faster if you have a GPU) and then auto-run on the next image in the folder.
19
- Next you can repeat #3-#5 as many times as is necessary.
20
- </li>
21
- <li> The trained model is available to use in the future in the GUI in the "custom model" section and is saved
22
- in your image folder.
23
- </li>
24
- </ol>
25
- </qt>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/gui/io.py DELETED
@@ -1,634 +0,0 @@
1
- """
2
- Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
- """
4
- import os, gc
5
- import numpy as np
6
- import cv2
7
- import fastremap
8
-
9
- from ..io import imread, imread_2D, imread_3D, imsave, outlines_to_text, add_model, remove_model, save_rois
10
- from ..models import normalize_default, MODEL_DIR, MODEL_LIST_PATH, get_user_models
11
- from ..utils import masks_to_outlines, outlines_list
12
-
13
- try:
14
- import qtpy
15
- from qtpy.QtWidgets import QFileDialog
16
- GUI = True
17
- except:
18
- GUI = False
19
-
20
- try:
21
- import matplotlib.pyplot as plt
22
- MATPLOTLIB = True
23
- except:
24
- MATPLOTLIB = False
25
-
26
-
27
- def _init_model_list(parent):
28
- MODEL_DIR.mkdir(parents=True, exist_ok=True)
29
- parent.model_list_path = MODEL_LIST_PATH
30
- parent.model_strings = get_user_models()
31
-
32
-
33
- def _add_model(parent, filename=None, load_model=True):
34
- if filename is None:
35
- name = QFileDialog.getOpenFileName(parent, "Add model to GUI")
36
- filename = name[0]
37
- add_model(filename)
38
- fname = os.path.split(filename)[-1]
39
- parent.ModelChooseC.addItems([fname])
40
- parent.model_strings.append(fname)
41
-
42
- for ind, model_string in enumerate(parent.model_strings[:-1]):
43
- if model_string == fname:
44
- _remove_model(parent, ind=ind + 1, verbose=False)
45
-
46
- parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
47
- if load_model:
48
- parent.model_choose(custom=True)
49
-
50
-
51
- def _remove_model(parent, ind=None, verbose=True):
52
- if ind is None:
53
- ind = parent.ModelChooseC.currentIndex()
54
- if ind > 0:
55
- ind -= 1
56
- parent.ModelChooseC.removeItem(ind + 1)
57
- del parent.model_strings[ind]
58
- # remove model from txt path
59
- modelstr = parent.ModelChooseC.currentText()
60
- remove_model(modelstr)
61
- if len(parent.model_strings) > 0:
62
- parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
63
- else:
64
- parent.ModelChooseC.setCurrentIndex(0)
65
- else:
66
- print("ERROR: no model selected to delete")
67
-
68
-
69
- def _get_train_set(image_names):
70
- """ get training data and labels for images in current folder image_names"""
71
- train_data, train_labels, train_files = [], [], []
72
- restore = None
73
- normalize_params = normalize_default
74
- for image_name_full in image_names:
75
- image_name = os.path.splitext(image_name_full)[0]
76
- label_name = None
77
- if os.path.exists(image_name + "_seg.npy"):
78
- dat = np.load(image_name + "_seg.npy", allow_pickle=True).item()
79
- masks = dat["masks"].squeeze()
80
- if masks.ndim == 2:
81
- fastremap.renumber(masks, in_place=True)
82
- label_name = image_name + "_seg.npy"
83
- else:
84
- print(f"GUI_INFO: _seg.npy found for {image_name} but masks.ndim!=2")
85
- if "img_restore" in dat:
86
- data = dat["img_restore"].squeeze()
87
- restore = dat["restore"]
88
- else:
89
- data = imread(image_name_full)
90
- normalize_params = dat[
91
- "normalize_params"] if "normalize_params" in dat else normalize_default
92
- if label_name is not None:
93
- train_files.append(image_name_full)
94
- train_data.append(data)
95
- train_labels.append(masks)
96
- if restore:
97
- print(f"GUI_INFO: using {restore} images (dat['img_restore'])")
98
- return train_data, train_labels, train_files, restore, normalize_params
99
-
100
-
101
- def _load_image(parent, filename=None, load_seg=True, load_3D=False):
102
- """ load image with filename; if None, open QFileDialog
103
- if image is grey change view to default to grey scale
104
- """
105
-
106
- if parent.load_3D:
107
- load_3D = True
108
-
109
- if filename is None:
110
- name = QFileDialog.getOpenFileName(parent, "Load image")
111
- filename = name[0]
112
- if filename == "":
113
- return
114
- manual_file = os.path.splitext(filename)[0] + "_seg.npy"
115
- load_mask = False
116
- if load_seg:
117
- if os.path.isfile(manual_file) and not parent.autoloadMasks.isChecked():
118
- if filename is not None:
119
- image = (imread_2D(filename) if not load_3D else
120
- imread_3D(filename))
121
- else:
122
- image = None
123
- _load_seg(parent, manual_file, image=image, image_file=filename,
124
- load_3D=load_3D)
125
- return
126
- elif parent.autoloadMasks.isChecked():
127
- mask_file = os.path.splitext(filename)[0] + "_masks" + os.path.splitext(
128
- filename)[-1]
129
- mask_file = os.path.splitext(filename)[
130
- 0] + "_masks.tif" if not os.path.isfile(mask_file) else mask_file
131
- load_mask = True if os.path.isfile(mask_file) else False
132
- try:
133
- print(f"GUI_INFO: loading image: {filename}")
134
- if not load_3D:
135
- image = imread_2D(filename)
136
- else:
137
- image = imread_3D(filename)
138
- parent.loaded = True
139
- except Exception as e:
140
- print("ERROR: images not compatible")
141
- print(f"ERROR: {e}")
142
-
143
- if parent.loaded:
144
- parent.reset()
145
- parent.filename = filename
146
- filename = os.path.split(parent.filename)[-1]
147
- _initialize_images(parent, image, load_3D=load_3D)
148
- parent.loaded = True
149
- parent.enable_buttons()
150
- if load_mask:
151
- _load_masks(parent, filename=mask_file)
152
-
153
- # check if gray and adjust viewer:
154
- if len(np.unique(image[..., 1:])) == 1:
155
- parent.color = 4
156
- parent.RGBDropDown.setCurrentIndex(4) # gray
157
- parent.update_plot()
158
-
159
-
160
- def _initialize_images(parent, image, load_3D=False):
161
- """ format image for GUI
162
-
163
- assumes image is Z x W x H x C
164
-
165
- """
166
- load_3D = parent.load_3D if load_3D is False else load_3D
167
-
168
- parent.stack = image
169
- print(f"GUI_INFO: image shape: {image.shape}")
170
- if load_3D:
171
- parent.NZ = len(parent.stack)
172
- parent.scroll.setMaximum(parent.NZ - 1)
173
- else:
174
- parent.NZ = 1
175
- parent.stack = parent.stack[np.newaxis, ...]
176
-
177
- img_min = image.min()
178
- img_max = image.max()
179
- parent.stack = parent.stack.astype(np.float32)
180
- parent.stack -= img_min
181
- if img_max > img_min + 1e-3:
182
- parent.stack /= (img_max - img_min)
183
- parent.stack *= 255
184
-
185
- if load_3D:
186
- print("GUI_INFO: converted to float and normalized values to 0.0->255.0")
187
-
188
- del image
189
- gc.collect()
190
-
191
- parent.imask = 0
192
- parent.Ly, parent.Lx = parent.stack.shape[-3:-1]
193
- parent.Ly0, parent.Lx0 = parent.stack.shape[-3:-1]
194
- parent.layerz = 255 * np.ones((parent.Ly, parent.Lx, 4), "uint8")
195
- if hasattr(parent, "stack_filtered"):
196
- parent.Lyr, parent.Lxr = parent.stack_filtered.shape[-3:-1]
197
- elif parent.restore and "upsample" in parent.restore:
198
- parent.Lyr, parent.Lxr = int(parent.Ly * parent.ratio), int(parent.Lx *
199
- parent.ratio)
200
- else:
201
- parent.Lyr, parent.Lxr = parent.Ly, parent.Lx
202
- parent.clear_all()
203
-
204
- if not hasattr(parent, "stack_filtered") and parent.restore:
205
- print("GUI_INFO: no 'img_restore' found, applying current settings")
206
- parent.compute_restore()
207
-
208
- if parent.autobtn.isChecked():
209
- if parent.restore is None or parent.restore != "filter":
210
- print(
211
- "GUI_INFO: normalization checked: computing saturation levels (and optionally filtered image)"
212
- )
213
- parent.compute_saturation()
214
- # elif len(parent.saturation) != parent.NZ:
215
- # parent.saturation = []
216
- # for r in range(3):
217
- # parent.saturation.append([])
218
- # for n in range(parent.NZ):
219
- # parent.saturation[-1].append([0, 255])
220
- # parent.sliders[r].setValue([0, 255])
221
- parent.compute_scale()
222
- parent.track_changes = []
223
-
224
- if load_3D:
225
- parent.currentZ = int(np.floor(parent.NZ / 2))
226
- parent.scroll.setValue(parent.currentZ)
227
- parent.zpos.setText(str(parent.currentZ))
228
- else:
229
- parent.currentZ = 0
230
-
231
-
232
- def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False):
233
- """ load *_seg.npy with filename; if None, open QFileDialog """
234
- if filename is None:
235
- name = QFileDialog.getOpenFileName(parent, "Load labelled data", filter="*.npy")
236
- filename = name[0]
237
- try:
238
- dat = np.load(filename, allow_pickle=True).item()
239
- # check if there are keys in filename
240
- dat["outlines"]
241
- parent.loaded = True
242
- except:
243
- parent.loaded = False
244
- print("ERROR: not NPY")
245
- return
246
-
247
- parent.reset()
248
- if image is None:
249
- found_image = False
250
- if "filename" in dat:
251
- parent.filename = dat["filename"]
252
- if os.path.isfile(parent.filename):
253
- parent.filename = dat["filename"]
254
- found_image = True
255
- else:
256
- imgname = os.path.split(parent.filename)[1]
257
- root = os.path.split(filename)[0]
258
- parent.filename = root + "/" + imgname
259
- if os.path.isfile(parent.filename):
260
- found_image = True
261
- if found_image:
262
- try:
263
- print(parent.filename)
264
- image = (imread_2D(parent.filename) if not load_3D else
265
- imread_3D(parent.filename))
266
- except:
267
- parent.loaded = False
268
- found_image = False
269
- print("ERROR: cannot find image file, loading from npy")
270
- if not found_image:
271
- parent.filename = filename[:-8]
272
- print(parent.filename)
273
- if "img" in dat:
274
- image = dat["img"]
275
- else:
276
- print("ERROR: no image file found and no image in npy")
277
- return
278
- else:
279
- parent.filename = image_file
280
-
281
- parent.restore = None
282
- parent.ratio = 1.
283
-
284
- if "normalize_params" in dat:
285
- parent.set_normalize_params(dat["normalize_params"])
286
-
287
- _initialize_images(parent, image, load_3D=load_3D)
288
- print(parent.stack.shape)
289
-
290
- if "outlines" in dat:
291
- if isinstance(dat["outlines"], list):
292
- # old way of saving files
293
- dat["outlines"] = dat["outlines"][::-1]
294
- for k, outline in enumerate(dat["outlines"]):
295
- if "colors" in dat:
296
- color = dat["colors"][k]
297
- else:
298
- col_rand = np.random.randint(1000)
299
- color = parent.colormap[col_rand, :3]
300
- median = parent.add_mask(points=outline, color=color)
301
- if median is not None:
302
- parent.cellcolors = np.append(parent.cellcolors,
303
- color[np.newaxis, :], axis=0)
304
- parent.ncells += 1
305
- else:
306
- if dat["masks"].min() == -1:
307
- dat["masks"] += 1
308
- dat["outlines"] += 1
309
- parent.ncells.set(dat["masks"].max())
310
- if "colors" in dat and len(dat["colors"]) == dat["masks"].max():
311
- colors = dat["colors"]
312
- else:
313
- colors = parent.colormap[:parent.ncells.get(), :3]
314
-
315
- _masks_to_gui(parent, dat["masks"], outlines=dat["outlines"], colors=colors)
316
-
317
- parent.draw_layer()
318
-
319
- if "manual_changes" in dat:
320
- parent.track_changes = dat["manual_changes"]
321
- print("GUI_INFO: loaded in previous changes")
322
- if "zdraw" in dat:
323
- parent.zdraw = dat["zdraw"]
324
- else:
325
- parent.zdraw = [None for n in range(parent.ncells.get())]
326
- parent.loaded = True
327
- else:
328
- parent.clear_all()
329
-
330
- parent.ismanual = np.zeros(parent.ncells.get(), bool)
331
- if "ismanual" in dat:
332
- if len(dat["ismanual"]) == parent.ncells:
333
- parent.ismanual = dat["ismanual"]
334
-
335
- if "current_channel" in dat:
336
- parent.color = (dat["current_channel"] + 2) % 5
337
- parent.RGBDropDown.setCurrentIndex(parent.color)
338
-
339
- if "flows" in dat:
340
- parent.flows = dat["flows"]
341
- try:
342
- if parent.flows[0].shape[-3] != dat["masks"].shape[-2]:
343
- Ly, Lx = dat["masks"].shape[-2:]
344
- for i in range(len(parent.flows)):
345
- parent.flows[i] = cv2.resize(
346
- parent.flows[i].squeeze(), (Lx, Ly),
347
- interpolation=cv2.INTER_NEAREST)[np.newaxis, ...]
348
- if parent.NZ == 1:
349
- parent.recompute_masks = True
350
- else:
351
- parent.recompute_masks = False
352
-
353
- except:
354
- try:
355
- if len(parent.flows[0]) > 0:
356
- parent.flows = parent.flows[0]
357
- except:
358
- parent.flows = [[], [], [], [], [[]]]
359
- parent.recompute_masks = False
360
-
361
- parent.enable_buttons()
362
- parent.update_layer()
363
- del dat
364
- gc.collect()
365
-
366
-
367
- def _load_masks(parent, filename=None):
368
- """ load zeros-based masks (0=no cell, 1=cell 1, ...) """
369
- if filename is None:
370
- name = QFileDialog.getOpenFileName(parent, "Load masks (PNG or TIFF)")
371
- filename = name[0]
372
- print(f"GUI_INFO: loading masks: {filename}")
373
- masks = imread(filename)
374
- outlines = None
375
- if masks.ndim > 3:
376
- # Z x nchannels x Ly x Lx
377
- if masks.shape[-1] > 5:
378
- parent.flows = list(np.transpose(masks[:, :, :, 2:], (3, 0, 1, 2)))
379
- outlines = masks[..., 1]
380
- masks = masks[..., 0]
381
- else:
382
- parent.flows = list(np.transpose(masks[:, :, :, 1:], (3, 0, 1, 2)))
383
- masks = masks[..., 0]
384
- elif masks.ndim == 3:
385
- if masks.shape[-1] < 5:
386
- masks = masks[np.newaxis, :, :, 0]
387
- elif masks.ndim < 3:
388
- masks = masks[np.newaxis, :, :]
389
- # masks should be Z x Ly x Lx
390
- if masks.shape[0] != parent.NZ:
391
- print("ERROR: masks are not same depth (number of planes) as image stack")
392
- return
393
-
394
- _masks_to_gui(parent, masks, outlines)
395
- if parent.ncells > 0:
396
- parent.draw_layer()
397
- parent.toggle_mask_ops()
398
- del masks
399
- gc.collect()
400
- parent.update_layer()
401
- parent.update_plot()
402
-
403
-
404
- def _masks_to_gui(parent, masks, outlines=None, colors=None):
405
- """ masks loaded into GUI """
406
- # get unique values
407
- shape = masks.shape
408
- if len(fastremap.unique(masks)) != masks.max() + 1:
409
- print("GUI_INFO: renumbering masks")
410
- fastremap.renumber(masks, in_place=True)
411
- outlines = None
412
- masks = masks.reshape(shape)
413
- if masks.ndim == 2:
414
- outlines = None
415
- masks = masks.astype(np.uint16) if masks.max() < 2**16 - 1 else masks.astype(
416
- np.uint32)
417
- if parent.restore and "upsample" in parent.restore:
418
- parent.cellpix_resize = masks.copy()
419
- parent.cellpix = parent.cellpix_resize.copy()
420
- parent.cellpix_orig = cv2.resize(
421
- masks.squeeze(), (parent.Lx0, parent.Ly0),
422
- interpolation=cv2.INTER_NEAREST)[np.newaxis, :, :]
423
- parent.resize = True
424
- else:
425
- parent.cellpix = masks
426
- if parent.cellpix.ndim == 2:
427
- parent.cellpix = parent.cellpix[np.newaxis, :, :]
428
- if parent.restore and "upsample" in parent.restore:
429
- if parent.cellpix_resize.ndim == 2:
430
- parent.cellpix_resize = parent.cellpix_resize[np.newaxis, :, :]
431
- if parent.cellpix_orig.ndim == 2:
432
- parent.cellpix_orig = parent.cellpix_orig[np.newaxis, :, :]
433
-
434
- print(f"GUI_INFO: {masks.max()} masks found")
435
-
436
- # get outlines
437
- if outlines is None: # parent.outlinesOn
438
- parent.outpix = np.zeros_like(parent.cellpix)
439
- if parent.restore and "upsample" in parent.restore:
440
- parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
441
- for z in range(parent.NZ):
442
- outlines = masks_to_outlines(parent.cellpix[z])
443
- parent.outpix[z] = outlines * parent.cellpix[z]
444
- if parent.restore and "upsample" in parent.restore:
445
- outlines = masks_to_outlines(parent.cellpix_orig[z])
446
- parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
447
- if z % 50 == 0 and parent.NZ > 1:
448
- print("GUI_INFO: plane %d outlines processed" % z)
449
- if parent.restore and "upsample" in parent.restore:
450
- parent.outpix_resize = parent.outpix.copy()
451
- else:
452
- parent.outpix = outlines
453
- if parent.restore and "upsample" in parent.restore:
454
- parent.outpix_resize = parent.outpix.copy()
455
- parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
456
- for z in range(parent.NZ):
457
- outlines = masks_to_outlines(parent.cellpix_orig[z])
458
- parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
459
- if z % 50 == 0 and parent.NZ > 1:
460
- print("GUI_INFO: plane %d outlines processed" % z)
461
-
462
- if parent.outpix.ndim == 2:
463
- parent.outpix = parent.outpix[np.newaxis, :, :]
464
- if parent.restore and "upsample" in parent.restore:
465
- if parent.outpix_resize.ndim == 2:
466
- parent.outpix_resize = parent.outpix_resize[np.newaxis, :, :]
467
- if parent.outpix_orig.ndim == 2:
468
- parent.outpix_orig = parent.outpix_orig[np.newaxis, :, :]
469
-
470
- parent.ncells.set(parent.cellpix.max())
471
- colors = parent.colormap[:parent.ncells.get(), :3] if colors is None else colors
472
- print("GUI_INFO: creating cellcolors and drawing masks")
473
- parent.cellcolors = np.concatenate((np.array([[255, 255, 255]]), colors),
474
- axis=0).astype(np.uint8)
475
- if parent.ncells > 0:
476
- parent.draw_layer()
477
- parent.toggle_mask_ops()
478
- parent.ismanual = np.zeros(parent.ncells.get(), bool)
479
- parent.zdraw = list(-1 * np.ones(parent.ncells.get(), np.int16))
480
-
481
- if hasattr(parent, "stack_filtered"):
482
- parent.ViewDropDown.setCurrentIndex(parent.ViewDropDown.count() - 1)
483
- print("set denoised/filtered view")
484
- else:
485
- parent.ViewDropDown.setCurrentIndex(0)
486
-
487
-
488
- def _save_png(parent):
489
- """ save masks to png or tiff (if 3D) """
490
- filename = parent.filename
491
- base = os.path.splitext(filename)[0]
492
- if parent.NZ == 1:
493
- if parent.cellpix[0].max() > 65534:
494
- print("GUI_INFO: saving 2D masks to tif (too many masks for PNG)")
495
- imsave(base + "_cp_masks.tif", parent.cellpix[0])
496
- else:
497
- print("GUI_INFO: saving 2D masks to png")
498
- imsave(base + "_cp_masks.png", parent.cellpix[0].astype(np.uint16))
499
- else:
500
- print("GUI_INFO: saving 3D masks to tiff")
501
- imsave(base + "_cp_masks.tif", parent.cellpix)
502
-
503
-
504
- def _save_flows(parent):
505
- """ save flows and cellprob to tiff """
506
- filename = parent.filename
507
- base = os.path.splitext(filename)[0]
508
- print("GUI_INFO: saving flows and cellprob to tiff")
509
- if len(parent.flows) > 0:
510
- imsave(base + "_cp_cellprob.tif", parent.flows[1])
511
- for i in range(3):
512
- imsave(base + f"_cp_flows_{i}.tif", parent.flows[0][..., i])
513
- if len(parent.flows) > 2:
514
- imsave(base + "_cp_flows.tif", parent.flows[2])
515
- print("GUI_INFO: saved flows and cellprob")
516
- else:
517
- print("ERROR: no flows or cellprob found")
518
-
519
-
520
- def _save_rois(parent):
521
- """ save masks as rois in .zip file for ImageJ """
522
- filename = parent.filename
523
- if parent.NZ == 1:
524
- print(
525
- f"GUI_INFO: saving {parent.cellpix[0].max()} ImageJ ROIs to .zip archive.")
526
- save_rois(parent.cellpix[0], parent.filename)
527
- else:
528
- print("ERROR: cannot save 3D outlines")
529
-
530
-
531
- def _save_outlines(parent):
532
- filename = parent.filename
533
- base = os.path.splitext(filename)[0]
534
- if parent.NZ == 1:
535
- print(
536
- "GUI_INFO: saving 2D outlines to text file, see docs for info to load into ImageJ"
537
- )
538
- outlines = outlines_list(parent.cellpix[0])
539
- outlines_to_text(base, outlines)
540
- else:
541
- print("ERROR: cannot save 3D outlines")
542
-
543
-
544
- def _save_sets_with_check(parent):
545
- """ Save masks and update *_seg.npy file. Use this function when saving should be optional
546
- based on the disableAutosave checkbox. Otherwise, use _save_sets """
547
- if not parent.disableAutosave.isChecked():
548
- _save_sets(parent)
549
-
550
-
551
- def _save_sets(parent):
552
- """ save masks to *_seg.npy. This function should be used when saving
553
- is forced, e.g. when clicking the save button. Otherwise, use _save_sets_with_check
554
- """
555
- filename = parent.filename
556
- base = os.path.splitext(filename)[0]
557
- flow_threshold = parent.segmentation_settings.flow_threshold
558
- cellprob_threshold = parent.segmentation_settings.cellprob_threshold
559
-
560
- if parent.NZ > 1:
561
- dat = {
562
- "outlines":
563
- parent.outpix,
564
- "colors":
565
- parent.cellcolors[1:],
566
- "masks":
567
- parent.cellpix,
568
- "current_channel": (parent.color - 2) % 5,
569
- "filename":
570
- parent.filename,
571
- "flows":
572
- parent.flows,
573
- "zdraw":
574
- parent.zdraw,
575
- "model_path":
576
- parent.current_model_path
577
- if hasattr(parent, "current_model_path") else 0,
578
- "flow_threshold":
579
- flow_threshold,
580
- "cellprob_threshold":
581
- cellprob_threshold,
582
- "normalize_params":
583
- parent.get_normalize_params(),
584
- "restore":
585
- parent.restore,
586
- "ratio":
587
- parent.ratio,
588
- "diameter":
589
- parent.segmentation_settings.diameter
590
- }
591
- if parent.restore is not None:
592
- dat["img_restore"] = parent.stack_filtered
593
- else:
594
- dat = {
595
- "outlines":
596
- parent.outpix.squeeze() if parent.restore is None or
597
- not "upsample" in parent.restore else parent.outpix_resize.squeeze(),
598
- "colors":
599
- parent.cellcolors[1:],
600
- "masks":
601
- parent.cellpix.squeeze() if parent.restore is None or
602
- not "upsample" in parent.restore else parent.cellpix_resize.squeeze(),
603
- "filename":
604
- parent.filename,
605
- "flows":
606
- parent.flows,
607
- "ismanual":
608
- parent.ismanual,
609
- "manual_changes":
610
- parent.track_changes,
611
- "model_path":
612
- parent.current_model_path
613
- if hasattr(parent, "current_model_path") else 0,
614
- "flow_threshold":
615
- flow_threshold,
616
- "cellprob_threshold":
617
- cellprob_threshold,
618
- "normalize_params":
619
- parent.get_normalize_params(),
620
- "restore":
621
- parent.restore,
622
- "ratio":
623
- parent.ratio,
624
- "diameter":
625
- parent.segmentation_settings.diameter
626
- }
627
- if parent.restore is not None:
628
- dat["img_restore"] = parent.stack_filtered
629
- try:
630
- np.save(base + "_seg.npy", dat)
631
- print("GUI_INFO: %d ROIs saved to %s" % (parent.ncells.get(), base + "_seg.npy"))
632
- except Exception as e:
633
- print(f"ERROR: {e}")
634
- del dat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/gui/make_train.py DELETED
@@ -1,107 +0,0 @@
1
- import os, argparse
2
- import numpy as np
3
- from cellpose import io, transforms
4
-
5
-
6
- def main():
7
- parser = argparse.ArgumentParser(description='Make slices of XYZ image data for training. Assumes image is ZXYC unless specified otherwise using --channel_axis and --z_axis')
8
-
9
- input_img_args = parser.add_argument_group("input image arguments")
10
- input_img_args.add_argument('--dir', default=[], type=str,
11
- help='folder containing data to run or train on.')
12
- input_img_args.add_argument(
13
- '--image_path', default=[], type=str, help=
14
- 'if given and --dir not given, run on single image instead of folder (cannot train with this option)'
15
- )
16
- input_img_args.add_argument(
17
- '--look_one_level_down', action='store_true',
18
- help='run processing on all subdirectories of current folder')
19
- input_img_args.add_argument('--img_filter', default=[], type=str,
20
- help='end string for images to run on')
21
- input_img_args.add_argument(
22
- '--channel_axis', default=-1, type=int,
23
- help='axis of image which corresponds to image channels')
24
- input_img_args.add_argument('--z_axis', default=0, type=int,
25
- help='axis of image which corresponds to Z dimension')
26
- input_img_args.add_argument(
27
- '--chan', default=0, type=int, help=
28
- 'Deprecated')
29
- input_img_args.add_argument(
30
- '--chan2', default=0, type=int, help=
31
- 'Deprecated'
32
- )
33
- input_img_args.add_argument('--invert', action='store_true',
34
- help='invert grayscale channel')
35
- input_img_args.add_argument(
36
- '--all_channels', action='store_true', help=
37
- 'deprecated')
38
- input_img_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
39
- help="anisotropy of volume in 3D")
40
-
41
-
42
- # algorithm settings
43
- algorithm_args = parser.add_argument_group("algorithm arguments")
44
- algorithm_args.add_argument('--sharpen_radius', required=False, default=0.0,
45
- type=float, help='high-pass filtering radius. Default: %(default)s')
46
- algorithm_args.add_argument('--tile_norm', required=False, default=0, type=int,
47
- help='tile normalization block size. Default: %(default)s')
48
- algorithm_args.add_argument('--nimg_per_tif', required=False, default=10, type=int,
49
- help='number of crops in XY to save per tiff. Default: %(default)s')
50
- algorithm_args.add_argument('--crop_size', required=False, default=512, type=int,
51
- help='size of random crop to save. Default: %(default)s')
52
-
53
- args = parser.parse_args()
54
-
55
- # find images
56
- if len(args.img_filter) > 0:
57
- imf = args.img_filter
58
- else:
59
- imf = None
60
-
61
- if len(args.dir) > 0:
62
- image_names = io.get_image_files(args.dir, "_masks", imf=imf,
63
- look_one_level_down=args.look_one_level_down)
64
- dirname = args.dir
65
- else:
66
- if os.path.exists(args.image_path):
67
- image_names = [args.image_path]
68
- dirname = os.path.split(args.image_path)[0]
69
- else:
70
- raise ValueError(f"ERROR: no file found at {args.image_path}")
71
-
72
- np.random.seed(0)
73
- nimg_per_tif = args.nimg_per_tif
74
- crop_size = args.crop_size
75
- os.makedirs(os.path.join(dirname, 'train/'), exist_ok=True)
76
- pm = [(0, 1, 2, 3), (2, 0, 1, 3), (1, 0, 2, 3)]
77
- npm = ["YX", "ZY", "ZX"]
78
- for name in image_names:
79
- name0 = os.path.splitext(os.path.split(name)[-1])[0]
80
- img0 = io.imread_3D(name)
81
- try:
82
- img0 = transforms.convert_image(img0, channel_axis=args.channel_axis,
83
- z_axis=args.z_axis, do_3D=True)
84
- except ValueError:
85
- print('Error converting image. Did you provide the correct --channel_axis and --z_axis ?')
86
-
87
- for p in range(3):
88
- img = img0.transpose(pm[p]).copy()
89
- print(npm[p], img[0].shape)
90
- Ly, Lx = img.shape[1:3]
91
- imgs = img[np.random.permutation(img.shape[0])[:args.nimg_per_tif]]
92
- if args.anisotropy > 1.0 and p > 0:
93
- imgs = transforms.resize_image(imgs, Ly=int(args.anisotropy * Ly), Lx=Lx)
94
- for k, img in enumerate(imgs):
95
- if args.tile_norm:
96
- img = transforms.normalize99_tile(img, blocksize=args.tile_norm)
97
- if args.sharpen_radius:
98
- img = transforms.smooth_sharpen_img(img,
99
- sharpen_radius=args.sharpen_radius)
100
- ly = 0 if Ly - crop_size <= 0 else np.random.randint(0, Ly - crop_size)
101
- lx = 0 if Lx - crop_size <= 0 else np.random.randint(0, Lx - crop_size)
102
- io.imsave(os.path.join(dirname, f'train/{name0}_{npm[p]}_{k}.tif'),
103
- img[ly:ly + args.crop_size, lx:lx + args.crop_size].squeeze())
104
-
105
-
106
- if __name__ == '__main__':
107
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/gui/menus.py DELETED
@@ -1,145 +0,0 @@
1
- """
2
- Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
- """
4
- from qtpy.QtWidgets import QAction
5
- from . import io
6
-
7
-
8
- def mainmenu(parent):
9
- main_menu = parent.menuBar()
10
- file_menu = main_menu.addMenu("&File")
11
- # load processed data
12
- loadImg = QAction("&Load image (*.tif, *.png, *.jpg)", parent)
13
- loadImg.setShortcut("Ctrl+L")
14
- loadImg.triggered.connect(lambda: io._load_image(parent))
15
- file_menu.addAction(loadImg)
16
-
17
- parent.autoloadMasks = QAction("Autoload masks from _masks.tif file", parent,
18
- checkable=True)
19
- parent.autoloadMasks.setChecked(False)
20
- file_menu.addAction(parent.autoloadMasks)
21
-
22
- parent.disableAutosave = QAction("Disable autosave _seg.npy file", parent,
23
- checkable=True)
24
- parent.disableAutosave.setChecked(False)
25
- file_menu.addAction(parent.disableAutosave)
26
-
27
- parent.loadMasks = QAction("Load &masks (*.tif, *.png, *.jpg)", parent)
28
- parent.loadMasks.setShortcut("Ctrl+M")
29
- parent.loadMasks.triggered.connect(lambda: io._load_masks(parent))
30
- file_menu.addAction(parent.loadMasks)
31
- parent.loadMasks.setEnabled(False)
32
-
33
- loadManual = QAction("Load &processed/labelled image (*_seg.npy)", parent)
34
- loadManual.setShortcut("Ctrl+P")
35
- loadManual.triggered.connect(lambda: io._load_seg(parent))
36
- file_menu.addAction(loadManual)
37
-
38
- parent.saveSet = QAction("&Save masks and image (as *_seg.npy)", parent)
39
- parent.saveSet.setShortcut("Ctrl+S")
40
- parent.saveSet.triggered.connect(lambda: io._save_sets(parent))
41
- file_menu.addAction(parent.saveSet)
42
- parent.saveSet.setEnabled(False)
43
-
44
- parent.savePNG = QAction("Save masks as P&NG/tif", parent)
45
- parent.savePNG.setShortcut("Ctrl+N")
46
- parent.savePNG.triggered.connect(lambda: io._save_png(parent))
47
- file_menu.addAction(parent.savePNG)
48
- parent.savePNG.setEnabled(False)
49
-
50
- parent.saveOutlines = QAction("Save &Outlines as text for imageJ", parent)
51
- parent.saveOutlines.setShortcut("Ctrl+O")
52
- parent.saveOutlines.triggered.connect(lambda: io._save_outlines(parent))
53
- file_menu.addAction(parent.saveOutlines)
54
- parent.saveOutlines.setEnabled(False)
55
-
56
- parent.saveROIs = QAction("Save outlines as .zip archive of &ROI files for ImageJ",
57
- parent)
58
- parent.saveROIs.setShortcut("Ctrl+R")
59
- parent.saveROIs.triggered.connect(lambda: io._save_rois(parent))
60
- file_menu.addAction(parent.saveROIs)
61
- parent.saveROIs.setEnabled(False)
62
-
63
- parent.saveFlows = QAction("Save &Flows and cellprob as tif", parent)
64
- parent.saveFlows.setShortcut("Ctrl+F")
65
- parent.saveFlows.triggered.connect(lambda: io._save_flows(parent))
66
- file_menu.addAction(parent.saveFlows)
67
- parent.saveFlows.setEnabled(False)
68
-
69
-
70
- def editmenu(parent):
71
- main_menu = parent.menuBar()
72
- edit_menu = main_menu.addMenu("&Edit")
73
- parent.undo = QAction("Undo previous mask/trace", parent)
74
- parent.undo.setShortcut("Ctrl+Z")
75
- parent.undo.triggered.connect(parent.undo_action)
76
- parent.undo.setEnabled(False)
77
- edit_menu.addAction(parent.undo)
78
-
79
- parent.redo = QAction("Undo remove mask", parent)
80
- parent.redo.setShortcut("Ctrl+Y")
81
- parent.redo.triggered.connect(parent.undo_remove_action)
82
- parent.redo.setEnabled(False)
83
- edit_menu.addAction(parent.redo)
84
-
85
- parent.ClearButton = QAction("Clear all masks", parent)
86
- parent.ClearButton.setShortcut("Ctrl+0")
87
- parent.ClearButton.triggered.connect(parent.clear_all)
88
- parent.ClearButton.setEnabled(False)
89
- edit_menu.addAction(parent.ClearButton)
90
-
91
- parent.remcell = QAction("Remove selected cell (Ctrl+CLICK)", parent)
92
- parent.remcell.setShortcut("Ctrl+Click")
93
- parent.remcell.triggered.connect(parent.remove_action)
94
- parent.remcell.setEnabled(False)
95
- edit_menu.addAction(parent.remcell)
96
-
97
- parent.mergecell = QAction("FYI: Merge cells by Alt+Click", parent)
98
- parent.mergecell.setEnabled(False)
99
- edit_menu.addAction(parent.mergecell)
100
-
101
-
102
- def modelmenu(parent):
103
- main_menu = parent.menuBar()
104
- io._init_model_list(parent)
105
- model_menu = main_menu.addMenu("&Models")
106
- parent.addmodel = QAction("Add custom torch model to GUI", parent)
107
- #parent.addmodel.setShortcut("Ctrl+A")
108
- parent.addmodel.triggered.connect(parent.add_model)
109
- parent.addmodel.setEnabled(True)
110
- model_menu.addAction(parent.addmodel)
111
-
112
- parent.removemodel = QAction("Remove selected custom model from GUI", parent)
113
- #parent.removemodel.setShortcut("Ctrl+R")
114
- parent.removemodel.triggered.connect(parent.remove_model)
115
- parent.removemodel.setEnabled(True)
116
- model_menu.addAction(parent.removemodel)
117
-
118
- parent.newmodel = QAction("&Train new model with image+masks in folder", parent)
119
- parent.newmodel.setShortcut("Ctrl+T")
120
- parent.newmodel.triggered.connect(parent.new_model)
121
- parent.newmodel.setEnabled(False)
122
- model_menu.addAction(parent.newmodel)
123
-
124
- openTrainHelp = QAction("Training instructions", parent)
125
- openTrainHelp.triggered.connect(parent.train_help_window)
126
- model_menu.addAction(openTrainHelp)
127
-
128
-
129
- def helpmenu(parent):
130
- main_menu = parent.menuBar()
131
- help_menu = main_menu.addMenu("&Help")
132
-
133
- openHelp = QAction("&Help with GUI", parent)
134
- openHelp.setShortcut("Ctrl+H")
135
- openHelp.triggered.connect(parent.help_window)
136
- help_menu.addAction(openHelp)
137
-
138
- openGUI = QAction("&GUI layout", parent)
139
- openGUI.setShortcut("Ctrl+G")
140
- openGUI.triggered.connect(parent.gui_window)
141
- help_menu.addAction(openGUI)
142
-
143
- openTrainHelp = QAction("Training instructions", parent)
144
- openTrainHelp.triggered.connect(parent.train_help_window)
145
- help_menu.addAction(openTrainHelp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/vit_sam_new.py DELETED
@@ -1,197 +0,0 @@
1
- """
2
- Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
3
- """
4
-
5
- import torch
6
- from segment_anything import sam_model_registry
7
- torch.backends.cuda.matmul.allow_tf32 = True
8
- from torch import nn
9
- import torch.nn.functional as F
10
-
11
- class Transformer(nn.Module):
12
- def __init__(self, backbone="vit_l", ps=16, nout=3, bsize=256, rdrop=0.4,
13
- checkpoint=None, dtype=torch.float32):
14
- super(Transformer, self).__init__()
15
- """
16
- print(self.encoder.patch_embed)
17
- PatchEmbed(
18
- (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
19
- )
20
- print(self.encoder.neck)
21
- Sequential(
22
- (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
23
- (1): LayerNorm2d()
24
- (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
25
- (3): LayerNorm2d()
26
- )
27
- """
28
- # instantiate the vit model, default to not loading SAM
29
- # checkpoint = sam_vit_l_0b3195.pth is standard pretrained SAM
30
- if checkpoint is None:
31
- checkpoint = "sam_vit_l_0b3195.pth"
32
- self.encoder = sam_model_registry[backbone](checkpoint).image_encoder
33
- w = self.encoder.patch_embed.proj.weight.detach()
34
- nchan = w.shape[0]
35
-
36
- # change token size to ps x ps
37
- self.ps = ps
38
- # self.encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
39
- # self.encoder.patch_embed.proj.weight.data = w[:,:,::16//ps,::16//ps]
40
-
41
- # adjust position embeddings for new bsize and new token size
42
- ds = (1024 // 16) // (bsize // ps)
43
- self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:,::ds,::ds], requires_grad=True)
44
-
45
- # readout weights for nout output channels
46
- # if nout is changed, weights will not load correctly from pretrained Cellpose-SAM
47
- self.nout = nout
48
- self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1)
49
-
50
- # W2 reshapes token space to pixel space, not trainable
51
- self.W2 = nn.Parameter(torch.eye(self.nout * ps**2).reshape(self.nout*ps**2, self.nout, ps, ps),
52
- requires_grad=False)
53
-
54
- # fraction of layers to drop at random during training
55
- self.rdrop = rdrop
56
-
57
- # average diameter of ROIs from training images from fine-tuning
58
- self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False)
59
- # average diameter of ROIs during main training
60
- self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False)
61
-
62
- # set attention to global in every layer
63
- for blk in self.encoder.blocks:
64
- blk.window_size = 0
65
-
66
- self.dtype = dtype
67
-
68
- def forward(self, x, feat=None):
69
- # same progression as SAM until readout
70
- x = self.encoder.patch_embed(x)
71
- if feat is not None:
72
- feat = self.encoder.patch_embed(feat)
73
- x = x + x * feat * 0.5
74
-
75
- if self.encoder.pos_embed is not None:
76
- x = x + self.encoder.pos_embed
77
-
78
- if self.training and self.rdrop > 0:
79
- nlay = len(self.encoder.blocks)
80
- rdrop = (torch.rand((len(x), nlay), device=x.device) <
81
- torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
82
- for i, blk in enumerate(self.encoder.blocks):
83
- mask = rdrop[:,i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
84
- x = x * mask + blk(x) * (1-mask)
85
- else:
86
- for blk in self.encoder.blocks:
87
- x = blk(x)
88
-
89
- x = self.encoder.neck(x.permute(0, 3, 1, 2))
90
-
91
- # readout is changed here
92
- x1 = self.out(x)
93
- x1 = F.conv_transpose2d(x1, self.W2, stride = self.ps, padding = 0)
94
-
95
- # maintain the second output of feature size 256 for backwards compatibility
96
-
97
- return x1, torch.randn((x.shape[0], 256), device=x.device)
98
-
99
- def load_model(self, PATH, device, strict = False):
100
- state_dict = torch.load(PATH, map_location = device, weights_only=True)
101
- keys = [k for k in state_dict.keys()]
102
- if keys[0][:7] == "module.":
103
- from collections import OrderedDict
104
- new_state_dict = OrderedDict()
105
- for k, v in state_dict.items():
106
- name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel
107
- new_state_dict[name] = v
108
- self.load_state_dict(new_state_dict, strict = strict)
109
- else:
110
- self.load_state_dict(state_dict, strict = strict)
111
-
112
- if self.dtype != torch.float32:
113
- self = self.to(self.dtype)
114
-
115
-
116
- @property
117
- def device(self):
118
- """
119
- Get the device of the model.
120
-
121
- Returns:
122
- torch.device: The device of the model.
123
- """
124
- return next(self.parameters()).device
125
-
126
- def save_model(self, filename):
127
- """
128
- Save the model to a file.
129
-
130
- Args:
131
- filename (str): The path to the file where the model will be saved.
132
- """
133
- torch.save(self.state_dict(), filename)
134
-
135
-
136
-
137
- class CPnetBioImageIO(Transformer):
138
- """
139
- A subclass of the CP-SAM model compatible with the BioImage.IO Spec.
140
-
141
- This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec,
142
- allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo.
143
- """
144
-
145
- def forward(self, x):
146
- """
147
- Perform a forward pass of the CPnet model and return unpacked tensors.
148
-
149
- Args:
150
- x (torch.Tensor): Input tensor.
151
-
152
- Returns:
153
- tuple: A tuple containing the output tensor, style tensor, and downsampled tensors.
154
- """
155
- output_tensor, style_tensor, downsampled_tensors = super().forward(x)
156
- return output_tensor, style_tensor, *downsampled_tensors
157
-
158
-
159
- def load_model(self, filename, device=None):
160
- """
161
- Load the model from a file.
162
-
163
- Args:
164
- filename (str): The path to the file where the model is saved.
165
- device (torch.device, optional): The device to load the model on. Defaults to None.
166
- """
167
- if (device is not None) and (device.type != "cpu"):
168
- state_dict = torch.load(filename, map_location=device, weights_only=True)
169
- else:
170
- self.__init__(self.nout)
171
- state_dict = torch.load(filename, map_location=torch.device("cpu"),
172
- weights_only=True)
173
-
174
- self.load_state_dict(state_dict)
175
-
176
- def load_state_dict(self, state_dict):
177
- """
178
- Load the state dictionary into the model.
179
-
180
- This method overrides the default `load_state_dict` to handle Cellpose's custom
181
- loading mechanism and ensures compatibility with BioImage.IO Core.
182
-
183
- Args:
184
- state_dict (Mapping[str, Any]): A state dictionary to load into the model
185
- """
186
- if state_dict["output.2.weight"].shape[0] != self.nout:
187
- for name in self.state_dict():
188
- if "output" not in name:
189
- self.state_dict()[name].copy_(state_dict[name])
190
- else:
191
- super().load_state_dict(
192
- {name: param for name, param in state_dict.items()},
193
- strict=False)
194
-
195
-
196
-
197
-