foz committed
Commit 1d89ca0
Parents: d4c993e, 29f7eb3

Fix README

app.py ADDED
@@ -0,0 +1,81 @@
import gradio as gr
from tifffile import imread
from PIL import Image
from path_analysis.analyse import analyse_paths
import numpy as np


# Function to preview the imported image
def preview_image(file1):
    if file1:
        print('Uploading image', file1.name)
        im = imread(file1.name)
        print(im.ndim, im.shape)
        if im.ndim > 2:
            return Image.fromarray(np.max(im, axis=0))
        else:
            return Image.fromarray(im)
    else:
        return None


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Inputs for cell ID, image, and path
            cellid_input = gr.Textbox(label="Cell ID", placeholder="Image_1")
            image_input = gr.File(label="Input foci image")
            image_preview = gr.Image(label="Max projection of foci image")
            image_input.change(fn=preview_image, inputs=image_input, outputs=image_preview)
            path_input = gr.File(label="SNT traces file")

            # Additional options wrapped in an accordion for better UI experience
            with gr.Accordion("Additional options ..."):
                sphere_radius = gr.Number(label="Trace sphere radius (um)", value=0.1984125, interactive=True)
                peak_threshold = gr.Number(label="Peak relative threshold", value=0.4, interactive=True)
                # Resolutions for the xy and z axes
                with gr.Row():
                    xy_res = gr.Number(label='xy resolution (um)', value=0.0396825, interactive=True)
                    z_res = gr.Number(label='z resolution (um)', value=0.0909184, interactive=True)

                threshold_type = gr.Radio(["per-trace", "per-cell"], label="Threshold-type", value="per-trace", interactive=True)
                use_corrected_positions = gr.Checkbox(label="Correct foci position measurements", value=True, interactive=True)
                screening_distance = gr.Number(label='Screening distance (voxels)', value=10, interactive=True)

        # The output column showing the results of processing
        with gr.Column():
            trace_output = gr.Image(label="Overlayed paths")
            image_output = gr.Gallery(label="Traced paths")
            plot_output = gr.Plot(label="Foci intensity traces")
            data_output = gr.DataFrame(label="Detected peak data")  # , "Peak 1 pos", "Peak 1 int"])
            data_file_output = gr.File(label="Output data file (.csv)")

    def process(cellid_input, image_input, path_input, sphere_radius, peak_threshold, xy_res, z_res, threshold_type, use_corrected_positions, screening_distance):

        config = {'sphere_radius': sphere_radius,
                  'peak_threshold': peak_threshold,
                  'xy_res': xy_res,
                  'z_res': z_res,
                  'threshold_type': threshold_type,
                  'use_corrected_positions': use_corrected_positions,
                  'screening_distance': screening_distance,
                  }

        paths, traces, fig, extracted_peaks = analyse_paths(cellid_input, image_input.name, path_input.name, config)
        extracted_peaks.to_csv('output.csv')
        print('extracted', extracted_peaks)
        return paths, [Image.fromarray(im) for im in traces], fig, extracted_peaks, 'output.csv'

    with gr.Row():
        greet_btn = gr.Button("Process")
        greet_btn.click(fn=process, inputs=[cellid_input, image_input, path_input, sphere_radius, peak_threshold, xy_res, z_res, threshold_type, use_corrected_positions, screening_distance], outputs=[trace_output, image_output, plot_output, data_output, data_file_output], api_name="process")


if __name__ == "__main__":
    demo.launch()

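
The Gradio callback above is a thin wrapper around analyse_paths, so the analysis can also be run headlessly. A minimal sketch, assuming a local TIFF stack foci.tif and an SNT traces file traces.xml (both file names are placeholders, not files in this commit); the config values mirror the UI defaults:

from path_analysis.analyse import analyse_paths

config = {'sphere_radius': 0.1984125, 'peak_threshold': 0.4,
          'xy_res': 0.0396825, 'z_res': 0.0909184,
          'threshold_type': 'per-trace',
          'use_corrected_positions': True,
          'screening_distance': 10}

# Returns the path-overlay image, per-trace visualisations, the intensity figure and a peaks DataFrame
overlay, trace_images, fig, peaks = analyse_paths('Image_1', 'foci.tif', 'traces.xml', config)
peaks.to_csv('output.csv')
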
path_analysis/__init__.py ADDED
File without changes
path_analysis/analyse.py ADDED
@@ -0,0 +1,416 @@
import lxml.etree as ET
import gzip
import tifffile
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw
import pandas as pd
from itertools import cycle
from .data_preprocess import analyse_traces
import math
import scipy.linalg as la


def get_paths_from_traces_file(traces_file):
    """
    Parses the specified traces file and extracts paths and their lengths.

    Args:
        traces_file (str): Path to the XML traces file.

    Returns:
        tuple: A tuple containing a list of paths (each path is a list of tuples representing points)
               and a list of corresponding path lengths.
    """
    tree = ET.parse(traces_file)
    root = tree.getroot()
    all_paths = []
    path_lengths = []
    for path in root.findall('path'):
        length = path.get('reallength')
        path_points = []
        for point in path:
            path_points.append((int(point.get('x')), int(point.get('y')), int(point.get('z'))))
        all_paths.append(path_points)
        path_lengths.append(float(length))
    return all_paths, path_lengths


def calculate_path_length_partials(point_list, voxel_size=(1, 1, 1)):
    """
    Calculate the partial path length of a series of points.

    Args:
        point_list (list of tuple): List of points, each represented as a tuple of coordinates (x, y, z).
        voxel_size (tuple, optional): Size of the voxel in each dimension (x, y, z). Defaults to (1, 1, 1).

    Returns:
        numpy.ndarray: Array of cumulative partial path lengths at each point.
    """
    # Simple calculation
    section_lengths = [0.0]
    s = np.array(voxel_size)
    for i in range(len(point_list)-1):
        # Euclidean distance between successive points
        section_lengths.append(la.norm(s * (np.array(point_list[i+1]) - np.array(point_list[i]))))
    return np.cumsum(section_lengths)


def visualise_ordering(points_list, dim, wr=5, wc=5):
    """
    Visualize the ordering of points in an image.

    Args:
        points_list (list): List of points to be visualized.
        dim (tuple): Dimensions of the image (rows, columns, channels).
        wr (int, optional): Width of the region to visualize around the point in the row direction. Defaults to 5.
        wc (int, optional): Width of the region to visualize around the point in the column direction. Defaults to 5.

    Returns:
        np.array: An image array with visualized points.
    """
    # Visualizes the ordering of the points in the list on a blank image.
    rdim, cdim, _ = dim
    vis = np.zeros((rdim, cdim, 3), dtype=np.uint8)

    def get_col(i):
        r = int(255 * i/len(points_list))
        g = 255 - r
        return r, g, 0

    for n, p in enumerate(points_list):
        c, r, _ = map(int, p)
        vis[max(0, r-wr):min(rdim, r+wr+1), max(0, c-wc):min(cdim, c+wc+1)] = get_col(n)

    return vis


# A color map for paths
col_map = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255),
           (255, 127, 0), (255, 0, 127), (127, 255, 0), (0, 255, 127), (127, 0, 255), (0, 127, 255)]


def draw_paths(all_paths, foci_stack, foci_index=None, r=3, screened_foci_data=None):
    """
    Draws paths on the provided image stack and overlays markers for the foci.

    Args:
        all_paths (list): List of paths where each path is a list of points.
        foci_stack (np.array): 3D numpy array representing the image stack.
        foci_index (list, optional): List of lists of focus indices (along each path). Defaults to None.
        r (int, optional): Radius for the ellipse or line drawing around the focus. Defaults to 3.
        screened_foci_data (list, optional): List of RemovedPeakData for screened foci.

    Returns:
        PIL.Image.Image: An image with the drawn paths.
    """
    im = np.max(foci_stack, axis=0)
    im = (im/np.max(im)*255).astype(np.uint8)
    im = np.dstack((im,)*3)
    im = Image.fromarray(im)
    draw = ImageDraw.Draw(im)
    for i, (p, col) in enumerate(zip(all_paths, cycle(col_map))):
        draw.line([(u[0], u[1]) for u in p], fill=col)
        draw.text((p[0][0], p[0][1]), str(i+1), fill=col)

    if screened_foci_data is not None:
        for i, removed_peaks in enumerate(screened_foci_data):
            for p in removed_peaks:
                u = all_paths[i][p.idx]
                v = all_paths[p.screening_peak[0]][p.screening_peak[1]]
                draw.line((int(u[0]), int(u[1]), int(v[0]), int(v[1])), fill=(127, 127, 127), width=2)

    if foci_index is not None:
        for i, (idx, p, col) in enumerate(zip(foci_index, all_paths, cycle(col_map))):
            if len(idx):
                for j in idx:
                    draw.line((int(p[j][0]-r), int(p[j][1]), int(p[j][0]+r), int(p[j][1])), fill=col, width=2)
                    draw.line((int(p[j][0]), int(p[j][1]-r), int(p[j][0]), int(p[j][1]+r)), fill=col, width=2)
    return im


def measure_from_mask(mask, measure_stack):
    """
    Compute the sum of measure_stack values where the mask is equal to 1.

    Args:
        mask (numpy.ndarray): Binary mask where the measurement should be applied.
        measure_stack (numpy.ndarray): Stack of measurements.

    Returns:
        measure_stack.dtype: Sum of measure_stack values where the mask is 1.
    """
    return np.sum(mask * measure_stack)


# Max of measure_stack over region where mask==1
def max_from_mask(mask, measure_stack):
    """
    Compute the maximum of measure_stack values where the mask is equal to 1.

    Args:
        mask (numpy.ndarray): Binary mask where the measurement should be applied.
        measure_stack (numpy.ndarray): Stack of measurements.

    Returns:
        measure_stack.dtype: Maximum value of measure_stack where the mask is 1.
    """
    return np.max(mask * measure_stack)


def make_mask_s(p, melem, measure_stack):
    """
    Translate a mask to point p, ensuring correct treatment near the edges of the measure_stack.

    Args:
        p (tuple): Target point (r, c, z).
        melem (numpy.ndarray): Structuring element for the mask.
        measure_stack (numpy.ndarray): Stack of measurements.

    Returns:
        tuple: A tuple containing the translated mask and a section of the measure_stack.
    """
    R = [u//2 for u in melem.shape]

    r, c, z = p

    mask = np.zeros(melem.shape)

    m_data = np.zeros(melem.shape)
    s = measure_stack.shape
    o_1, o_2, o_3 = max(R[0]-r, 0), max(R[1]-c, 0), max(R[2]-z, 0)
    e_1, e_2, e_3 = min(R[0]-r+s[0], 2*R[0]+1), min(R[1]-c+s[1], 2*R[1]+1), min(R[2]-z+s[2], 2*R[2]+1)
    m_data[o_1:e_1, o_2:e_2, o_3:e_3] = measure_stack[max(r-R[0], 0):min(r+R[0]+1, s[0]), max(c-R[1], 0):min(c+R[1]+1, s[1]), max(z-R[2], 0):min(z+R[2]+1, s[2])]
    mask[o_1:e_1, o_2:e_2, o_3:e_3] = melem[o_1:e_1, o_2:e_2, o_3:e_3]

    return mask, m_data


def measure_at_point(p, melem, measure_stack, op='mean'):
    """
    Measure the mean or max value of measure_stack around a specific point using a structuring element.

    Args:
        p (tuple): Target point (r, c, z).
        melem (numpy.ndarray): Structuring element for the mask.
        measure_stack (numpy.ndarray): Stack of measurements.
        op (str, optional): Operation to be applied; either 'mean' or 'max'. Default is 'mean'.

    Returns:
        float: Measured value based on the specified operation.
    """
    p = map(int, p)
    if op == 'mean':
        mask, m_data = make_mask_s(p, melem, measure_stack)
        melem_size = np.sum(mask)
        return float(measure_from_mask(mask, m_data) / melem_size)
    else:
        mask, m_data = make_mask_s(p, melem, measure_stack)
        return float(max_from_mask(mask, m_data))


# Generate spherical region
def make_sphere(R=5, z_scale_ratio=2.3):
    """
    Generate a binary representation of a sphere in 3D space.

    Args:
        R (int, optional): Radius of the sphere. Default is 5. Centred on the centre of the middle voxel.
            Includes all voxels whose centre is within R of the middle voxel.
        z_scale_ratio (float, optional): Scaling factor for the z-axis. Default is 2.3.

    Returns:
        numpy.ndarray: Binary representation of the sphere.
    """
    R_z = int(math.ceil(R/z_scale_ratio))
    x, y, z = np.ogrid[-R:R+1, -R:R+1, -R_z:R_z+1]
    sphere = x**2 + y**2 + (z_scale_ratio * z)**2 <= R**2
    return sphere


# Measure the values of measure_stack at each of the points of points_list in turn.
# Measurement is the mean / max (specified by op) on the spherical region about each point
def measure_all_with_sphere(points_list, measure_stack, op='mean', R=5, z_scale_ratio=2.3):
    """
    Measure the values of measure_stack at each point in a list using a spherical region.

    Args:
        points_list (list): List of points (r, c, z) to be measured.
        measure_stack (numpy.ndarray): Stack of measurements.
        op (str, optional): Operation to be applied; either 'mean' or 'max'. Default is 'mean'.
        R (int, optional): Radius of the sphere. Default is 5.
        z_scale_ratio (float, optional): Scaling factor for the z-axis. Default is 2.3.

    Returns:
        list: List of measured values for each point.
    """
    melem = make_sphere(R, z_scale_ratio)
    measure_func = lambda p: measure_at_point(p, melem, measure_stack, op)
    return list(map(measure_func, points_list))


# Measure fluorescence levels along ordered skeleton
def measure_chrom2(path, intensity, config):
    """
    Measure fluorescence levels along an ordered skeleton.

    Args:
        path (list): List of ordered path points (r, c, z).
        intensity (numpy.ndarray): 3D fluorescence data.
        config (dict): Configuration dictionary containing 'z_res', 'xy_res', and 'sphere_radius' values.

    Returns:
        tuple: A tuple containing the visualization, mean measurements, and max measurements along the path.
    """
    # Calculate size of spheroid used for measurement
    scale_ratio = config['z_res']/config['xy_res']
    sphere_xy_radius = int(math.ceil(config['sphere_radius']/config['xy_res']))

    vis = visualise_ordering(path, dim=intensity.shape, wr=sphere_xy_radius, wc=sphere_xy_radius)

    measurements = measure_all_with_sphere(path, intensity, op='mean', R=sphere_xy_radius, z_scale_ratio=scale_ratio)
    measurements_max = measure_all_with_sphere(path, intensity, op='max', R=sphere_xy_radius, z_scale_ratio=scale_ratio)

    return vis, measurements, measurements_max


def extract_peaks(cell_id, all_paths, path_lengths, measured_traces, config):
    """
    Extract peak information from given traces and compile them into a DataFrame.

    Args:
        - cell_id (int or str): Identifier for the cell being analyzed.
        - all_paths (list of lists): Contains ordered path points for multiple paths.
        - path_lengths (list of floats): List containing lengths of each path in all_paths.
        - measured_traces (list of lists): Contains fluorescence measurement values along the paths.
        - config (dict): Configuration dictionary containing:
            - 'peak_threshold': Threshold value to determine a peak in the trace.
            - 'sphere_radius': Radius of the sphere used in fluorescence measurement.

    Returns:
        - pd.DataFrame: DataFrame containing peak information for each path.
        - list of lists: Absolute intensities of the detected foci.
        - list of lists: Index positions of the detected foci.
        - list of lists: RemovedPeakData for foci screened out by a nearby brighter peak.
        - list of lists: Absolute focus intensity threshold for each trace.
        - list of numpy.ndarray: For each trace, distances of each point from start of trace in microns.
    """
    n_paths = len(all_paths)

    data = []
    foci_absolute_intensity, foci_position, foci_position_index, screened_foci_data, trace_median_intensities, trace_thresholds = analyse_traces(all_paths, path_lengths, measured_traces, config)

    # Normalize foci intensities (for quantification) using trace medians as estimates of background
    foci_intensities = []
    for path_foci_abs_int, tmi in zip(foci_absolute_intensity, trace_median_intensities):
        foci_intensities.extend(list(path_foci_abs_int - tmi))

    # Divide all foci intensities by the mean within the cell
    mean_intensity = np.mean(foci_intensities)
    trace_positions = []

    for i in range(n_paths):

        # Calculate real (Euclidean) distance of each point along the traced path
        pl = calculate_path_length_partials(all_paths[i], (config['xy_res'], config['xy_res'], config['z_res']))

        path_data = {'Cell_ID': cell_id,
                     'Trace': i+1,
                     'SNT_trace_length(um)': path_lengths[i],
                     'Measured_trace_length(um)': pl[-1],
                     'Trace_median_intensity': trace_median_intensities[i],
                     'Detection_sphere_radius(um)': config['sphere_radius'],
                     'Screening_distance(voxels)': config['screening_distance'],
                     'Foci_ID_threshold': config['peak_threshold'],
                     'Trace_foci_number': len(foci_position_index[i])}
        for j, (idx, u, v) in enumerate(zip(foci_position_index[i], foci_position[i], foci_absolute_intensity[i])):
            if config['use_corrected_positions']:
                # Use the calculated position along the traced path
                path_data[f'Foci_{j+1}_position(um)'] = pl[idx]
            else:
                # Use the measured trace length (from SNT), and assume all steps of path are approximately the same length
                path_data[f'Foci_{j+1}_position(um)'] = u
            # The original measured intensity (mean in spheroid around detected peak)
            path_data[f'Foci_{j+1}_absolute_intensity'] = v
            # Measure relative intensity by removing per-trace background and dividing by cell total
            path_data[f'Foci_{j+1}_relative_intensity'] = (v - trace_median_intensities[i])/mean_intensity
        data.append(path_data)
        trace_positions.append(pl)
    return pd.DataFrame(data), foci_absolute_intensity, foci_position_index, screened_foci_data, trace_thresholds, trace_positions


def analyse_paths(cell_id,
                  foci_file,
                  traces_file,
                  config
                  ):
    """
    Analyzes paths for the given cell ID using provided foci and trace files.

    Args:
        cell_id (int/str): Identifier for the cell.
        foci_file (str): Path to the foci image file.
        traces_file (str): Path to the XML traces file.
        config (dict): Configuration dictionary containing necessary parameters such as resolutions and thresholds.

    Returns:
        tuple: A tuple containing an overlay image of the traces, visualization images for each trace,
               a figure with plotted measurements, and a dataframe with extracted peaks.
    """
    # Read stack
    foci_stack = tifffile.imread(foci_file)

    # If 2D add additional (z) dimension
    if foci_stack.ndim == 2:
        foci_stack = foci_stack[None, :, :]

    all_paths, path_lengths = get_paths_from_traces_file(traces_file)

    all_trace_vis = []  # Per-path visualizations
    all_m = []          # Per-path measured intensities
    for p in all_paths:
        # Measure intensity along path - transpose the stack ZYX -> XYZ
        vis, m, _ = measure_chrom2(p, foci_stack.transpose(2, 1, 0), config)
        all_trace_vis.append(vis)
        all_m.append(m)

    # Extract all data from paths and traces
    extracted_peaks, foci_absolute_intensity, foci_pos_index, screened_foci_data, trace_thresholds, trace_positions = extract_peaks(cell_id, all_paths, path_lengths, all_m, config)

    # Plot per-path measured intensities and indicate foci
    n_cols = 2
    n_rows = (len(all_paths)+n_cols-1)//n_cols
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 3*n_rows))
    ax = ax.flatten()

    for i, m in enumerate(all_m):
        ax[i].set_title(f'Trace {i+1}')
        ax[i].plot(trace_positions[i], m)
        if len(foci_pos_index[i]):
            # Plot detected foci
            ax[i].plot(trace_positions[i][foci_pos_index[i]], np.array(m)[foci_pos_index[i]], 'rx')

        if len(screened_foci_data[i]):
            # Indicate screened foci by gray circles on plots
            screened_foci_pos_index = [u.idx for u in screened_foci_data[i]]
            ax[i].plot(trace_positions[i][screened_foci_pos_index], np.array(m)[screened_foci_pos_index], color=(0.5, 0.5, 0.5), marker='o', linestyle='None')

        # Show per-trace intensity thresholds with red dotted lines
        if trace_thresholds[i] is not None:
            ax[i].axhline(trace_thresholds[i], c='r', ls=':')
        ax[i].set_xlabel('Distance from start (um)')
        ax[i].set_ylabel('Intensity')
    # Hide excess plots
    for i in range(len(all_m), n_cols*n_rows):
        ax[i].axis('off')

    plt.tight_layout()
    trace_overlay = draw_paths(all_paths, foci_stack, foci_index=foci_pos_index, screened_foci_data=screened_foci_data)

    return trace_overlay, all_trace_vis, fig, extracted_peaks

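
To illustrate the two geometric helpers defined above, a short sketch; the point coordinates and voxel sizes here are arbitrary example values, not values taken from this commit:

import numpy as np
from path_analysis.analyse import calculate_path_length_partials, make_sphere

# Cumulative Euclidean distance along a 3-point path with anisotropic voxels (xy = 0.04 um, z = 0.09 um)
pts = [(0, 0, 0), (1, 0, 0), (1, 1, 1)]
print(calculate_path_length_partials(pts, voxel_size=(0.04, 0.04, 0.09)))
# -> approximately [0.0, 0.040, 0.138]; the last entry is the measured trace length

# Spheroidal structuring element used by measure_all_with_sphere;
# its z extent is reduced by the z/xy anisotropy ratio
sphere = make_sphere(R=5, z_scale_ratio=2.3)
print(sphere.shape)  # (11, 11, 7)
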
path_analysis/data_preprocess.py ADDED
@@ -0,0 +1,395 @@
from dataclasses import dataclass
import numpy as np
import scipy.linalg as la
from scipy.signal import find_peaks
from math import ceil


def thin_peaks(peak_list, dmin=10, voxel_size=(1, 1, 1), return_larger_peaks=False):
    """
    Remove peaks within a specified distance of each other, retaining the peak with the highest intensity.

    Args:
    - peak_list (list of PeakData): Each element contains:
        - pos (list of float): 3D coordinates of the peak.
        - intensity (float): The intensity value of the peak.
        - key (tuple): A unique identifier or index for the peak (#trace, #peak)
    - dmin (float, optional): Minimum distance between peaks. Peaks closer than this threshold will be thinned. Defaults to 10.
    - voxel_size (tuple, optional): Voxel size in each dimension (x, y, z), used to scale distances. Defaults to (1, 1, 1).
    - return_larger_peaks (bool, optional): Also return the larger peak responsible for each removal.

    Returns:
    - list of tuples: A list containing keys of the removed peaks.
    If return_larger_peaks is True, additionally:
    - list of tuples: A list containing the keys of the larger peak causing each peak to be removed.

    Notes:
    - The function uses the L2 norm (Euclidean distance) to compute the distance between peaks.
    - When two peaks are within `dmin` distance, the peak with the lower intensity is removed.
    """
    removed_peaks = []
    removed_larger_peaks = []
    for i in range(len(peak_list)):
        if peak_list[i].key in removed_peaks:
            continue
        for j in range(len(peak_list)):
            if i == j:
                continue
            if peak_list[j].key in removed_peaks:
                continue
            d = (np.array(peak_list[i].pos) - np.array(peak_list[j].pos))*np.array(voxel_size)
            d = la.norm(d)
            if d < dmin:
                hi = peak_list[i].intensity
                hj = peak_list[j].intensity
                if hi < hj:
                    removed_peaks.append(peak_list[i].key)
                    removed_larger_peaks.append(peak_list[j].key)
                    break
                else:
                    removed_peaks.append(peak_list[j].key)
                    removed_larger_peaks.append(peak_list[i].key)

    if return_larger_peaks:
        return removed_peaks, removed_larger_peaks
    else:
        return removed_peaks


@dataclass
class CellData(object):
    """Represents data related to a single cell.

    Attributes:
        pathdata_list (list): A list of PathData objects representing the various paths associated with the cell.
    """
    pathdata_list: list


@dataclass
class RemovedPeakData(object):
    """Represents data related to a removed peak.

    Attributes:
        idx (int): Index of peak along path
        screening_peak (tuple): (path_idx, position along path) for screening peak
    """
    idx: int
    screening_peak: tuple


@dataclass
class PathData(object):
    """Represents data related to a specific path in the cell.

    This dataclass encapsulates information about the peaks,
    the defining points, the fluorescence values, and the path length of a specific path.

    Attributes:
        peaks (list): List of peaks in the path (indices of positions in points, o_intensity).
        removed_peaks (list): List of peaks in the path which have been removed because of a nearby larger peak.
        points (list): List of points defining the path.
        o_intensity (list): List of (unnormalized) fluorescence intensity values along the path.
        SC_length (float): Length of the path.
    """
    peaks: list
    removed_peaks: list
    points: list
    o_intensity: list
    SC_length: float


@dataclass
class PeakData(object):
    pos: tuple
    intensity: float
    key: tuple


def find_peaks2(v, distance=5, prominence=0.5):
    """
    Find peaks in a 1D array with extended boundary handling.

    The function pads the input array at both ends to handle boundary peaks. It then identifies peaks in the extended array
    and maps them back to the original input array.

    Args:
    - v (numpy.ndarray): 1D input array in which to find peaks.
    - distance (int, optional): Minimum number of array elements that separate two peaks. Defaults to 5.
    - prominence (float, optional): Minimum prominence required for a peak to be identified. Defaults to 0.5.

    Returns:
    - list of int: List containing the indices of the identified peaks in the original input array.
    - dict: Information about the properties of the identified peaks (as returned by scipy.signal.find_peaks).
    """
    pad = int(ceil(distance))+1
    v_ext = np.concatenate([np.ones((pad,), dtype=v.dtype)*np.min(v), v, np.ones((pad,), dtype=v.dtype)*np.min(v)])

    assert(len(v_ext) == len(v)+2*pad)
    peaks, _ = find_peaks(v_ext, distance=distance, prominence=prominence)
    peaks = peaks - pad
    n_peaks = []
    for i in peaks:
        if 0 <= i < len(v):
            n_peaks.append(i)
        else:
            raise Exception
    return n_peaks, _


def process_cell_traces(all_paths, path_lengths, measured_trace_fluorescence, dmin=10):
    """
    Process traces of cells to extract peak information and organize the data.

    The function normalizes fluorescence data, finds peaks, refines peak information,
    removes unwanted peaks that might be due to close proximity of bright peaks from
    other paths, and organizes all the information into a structured data format.

    Args:
        all_paths (list of list of tuples): A list containing paths, where each path is
                                            represented as a list of 3D coordinate tuples.
        path_lengths (list of float): List of path lengths corresponding to the provided paths.
        measured_trace_fluorescence (list of list of float): A list containing fluorescence
                                                             data corresponding to each path point.
        dmin (float): Distance below which brighter peaks screen less bright ones.

    Returns:
        CellData: An object containing organized peak and path data for a given cell.

    Note:
        - The function assumes that each path and its corresponding length and fluorescence data
          are positioned at the same index in their respective lists.
    """
    cell_peaks = []

    for points, o_intensity in zip(all_paths, measured_trace_fluorescence):

        # For peak determination normalize each trace to have mean zero and s.d. 1
        intensity_normalized = (o_intensity - np.mean(o_intensity))/np.std(o_intensity)

        # Find peaks - these will be further refined later
        p, _ = find_peaks2(intensity_normalized, distance=5, prominence=0.5*np.std(intensity_normalized))
        peaks = np.array(p, dtype=np.int32)

        # Store peak data - using original values, not normalized ones
        peak_mean_heights = [o_intensity[u] for u in peaks]
        peak_points = [points[u] for u in peaks]

        cell_peaks.append((peaks, peak_points, peak_mean_heights))

    # Eliminate peaks which have another larger peak nearby (in 3D space, on any chromosome).
    # This aims to remove small peaks in the mean intensity generated when an SC passes close
    # to a bright peak on another SC - this is nearby in space, but brighter.

    to_thin = []
    for k in range(len(cell_peaks)):
        for u in range(len(cell_peaks[k][0])):
            to_thin.append(PeakData(pos=cell_peaks[k][1][u], intensity=cell_peaks[k][2][u], key=(k, u)))

    # Exclude any peak with a nearby brighter peak (on any SC)
    removed_peaks, removed_larger_peaks = thin_peaks(to_thin, return_larger_peaks=True, dmin=dmin)

    # Clean up and remove these peaks
    new_cell_peaks = []
    removed_cell_peaks = []
    removed_cell_peaks_larger = []
    for path_idx in range(len(cell_peaks)):
        path_retained_peaks = []
        path_removed_peaks = []
        path_peaks = cell_peaks[path_idx][0]

        for peak_idx in range(len(path_peaks)):
            if (path_idx, peak_idx) not in removed_peaks:
                path_retained_peaks.append(path_peaks[peak_idx])
            else:
                # What's the larger point?
                idx = removed_peaks.index((path_idx, peak_idx))
                larger_path, larger_idx = removed_larger_peaks[idx]
                path_removed_peaks.append(RemovedPeakData(idx=path_peaks[peak_idx], screening_peak=(larger_path, cell_peaks[larger_path][0][larger_idx])))

        new_cell_peaks.append(path_retained_peaks)
        removed_cell_peaks.append(path_removed_peaks)

    cell_peaks = new_cell_peaks
    pd_list = []

    # Save peak positions, absolute intensities, and length for each SC
    for k in range(len(all_paths)):

        points, o_intensity = all_paths[k], measured_trace_fluorescence[k]

        peaks = cell_peaks[k]
        removed_peaks = removed_cell_peaks[k]

        pd = PathData(peaks=peaks, removed_peaks=removed_peaks, points=points, o_intensity=o_intensity, SC_length=path_lengths[k])
        pd_list.append(pd)

    cd = CellData(pathdata_list=pd_list)

    return cd


alpha_max = 0.4


# Criterion used for identifying peak as a focus - normalized (with mean and s.d.)
# intensity levels being above 0.4 times maximum peak level
def focus_criterion(pos, v, alpha=alpha_max):
    """
    Identify and return positions where values in the array `v` exceed a certain threshold.

    The threshold is computed as `alpha` times the maximum value in `v`.

    Args:
    - pos (numpy.ndarray): Array of positions.
    - v (numpy.ndarray): 1D array of values, e.g., intensities.
    - alpha (float, optional): A scaling factor for the threshold. Defaults to `alpha_max`.

    Returns:
    - numpy.ndarray: Array of positions where corresponding values in `v` exceed the threshold.
    """
    if len(v):
        idx = (v >= alpha*np.max(v))
        return np.array(pos[idx])
    else:
        return np.array([], dtype=np.int32)


def analyse_celldata(cell_data, config):
    """
    Analyse the provided cell data to extract focus-related information.

    Args:
        cell_data (CellData): An instance of the CellData class containing path data information.
        config (dict): Configuration dictionary containing 'peak_threshold' and 'threshold_type'
            'peak_threshold' (float) - threshold for calling peaks as foci
            'threshold_type' (str) - either 'per-trace' or 'per-cell'

    Returns:
        tuple: A tuple containing:
            - foci_abs_intensity (list): List of absolute intensities for the detected foci.
            - foci_pos (list): List of absolute positions of the detected foci.
            - foci_pos_index (list): List of indices of the detected foci.
            - screened_foci_data (list): List of RemovedPeakData indicating positions of removed peaks and the index of the larger peak.
            - trace_median_intensities (list): Per-trace median intensity.
            - trace_thresholds (list): Per-trace absolute threshold for calling peaks as foci.
    """
    foci_abs_intensity = []
    foci_pos = []
    foci_pos_index = []
    screened_foci_data = []
    trace_median_intensities = []
    trace_thresholds = []

    peak_threshold = config['peak_threshold']

    threshold_type = config['threshold_type']

    if threshold_type == 'per-trace':
        """
        Call extracted peaks as foci if intensity - trace_mean > peak_threshold * (trace_max_foci_intensity - trace_mean)
        """
        for path_data in cell_data.pathdata_list:
            peaks = np.array(path_data.peaks, dtype=np.int32)

            # Normalize extracted fluorescent intensities by subtracting mean (and dividing
            # by standard deviation - note that the latter should have no effect on the results).
            h = np.array(path_data.o_intensity)
            h = h - np.mean(h)
            h = h/np.std(h)
            # Extract foci according to criterion
            foci_idx = focus_criterion(peaks, h[peaks], peak_threshold)

            removed_peaks = path_data.removed_peaks
            removed_peaks_idx = np.array([u.idx for u in removed_peaks], dtype=np.int32)

            if len(peaks):
                trace_thresholds.append((1-peak_threshold)*np.mean(path_data.o_intensity) + peak_threshold*np.max(np.array(path_data.o_intensity)[peaks]))
            else:
                trace_thresholds.append(None)

            if len(removed_peaks):
                if len(peaks):
                    threshold = (1-peak_threshold)*np.mean(path_data.o_intensity) + peak_threshold*np.max(np.array(path_data.o_intensity)[peaks])
                else:
                    threshold = float('-inf')

                removed_peak_heights = np.array(path_data.o_intensity)[removed_peaks_idx]
                screened_foci_idx = np.where(removed_peak_heights > threshold)[0]

                screened_foci_data.append([removed_peaks[i] for i in screened_foci_idx])
            else:
                screened_foci_data.append([])

            pos_abs = (foci_idx/len(path_data.points))*path_data.SC_length
            foci_pos.append(pos_abs)
            foci_abs_intensity.append(np.array(path_data.o_intensity)[foci_idx])

            foci_pos_index.append(foci_idx)
            trace_median_intensities.append(np.median(path_data.o_intensity))

    elif threshold_type == 'per-cell':
        """
        Call extracted peaks as foci if intensity - trace_mean > peak_threshold * max(intensity - trace_mean)
        """
        max_cell_intensity = float("-inf")
        for path_data in cell_data.pathdata_list:

            # Normalize extracted fluorescent intensities by subtracting mean (and dividing
            # by standard deviation - note that the latter should have no effect on the results).
            h = np.array(path_data.o_intensity)
            h = h - np.mean(h)
            max_cell_intensity = max(max_cell_intensity, np.max(h))

        for path_data in cell_data.pathdata_list:
            peaks = np.array(path_data.peaks, dtype=np.int32)

            # Normalize extracted fluorescent intensities by subtracting mean (and dividing
            # by standard deviation - note that the latter should have no effect on the results).
            h = np.array(path_data.o_intensity)
            h = h - np.mean(h)

            foci_idx = peaks[h[peaks] > peak_threshold*max_cell_intensity]

            removed_peaks = path_data.removed_peaks
            removed_peaks_idx = np.array([u.idx for u in removed_peaks], dtype=np.int32)

            trace_thresholds.append(np.mean(path_data.o_intensity) + peak_threshold*max_cell_intensity)

            if len(removed_peaks):
                threshold = np.mean(path_data.o_intensity) + peak_threshold*max_cell_intensity

                removed_peak_heights = np.array(path_data.o_intensity)[removed_peaks_idx]
                screened_foci_idx = np.where(removed_peak_heights > threshold)[0]

                screened_foci_data.append([removed_peaks[i] for i in screened_foci_idx])
            else:
                screened_foci_data.append([])

            pos_abs = (foci_idx/len(path_data.points))*path_data.SC_length
            foci_pos.append(pos_abs)
            foci_abs_intensity.append(np.array(path_data.o_intensity)[foci_idx])

            foci_pos_index.append(foci_idx)
            trace_median_intensities.append(np.median(path_data.o_intensity))

    else:
        raise NotImplementedError

    return foci_abs_intensity, foci_pos, foci_pos_index, screened_foci_data, trace_median_intensities, trace_thresholds


def analyse_traces(all_paths, path_lengths, measured_trace_fluorescence, config):

    cd = process_cell_traces(all_paths, path_lengths, measured_trace_fluorescence, dmin=config['screening_distance'])

    return analyse_celldata(cd, config)

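
A small worked example of the per-trace focus-calling pipeline implemented above; the trace values are invented purely for illustration:

import numpy as np
from path_analysis.data_preprocess import find_peaks2, focus_criterion

trace = np.array([10.0, 11, 30, 12, 11, 18, 10, 50, 11, 10])
norm = (trace - trace.mean()) / trace.std()

# Candidate peaks in the normalized trace, as in process_cell_traces
peaks, _ = find_peaks2(norm, distance=5, prominence=0.5 * np.std(norm))

# Keep only candidates within peak_threshold (0.4) of the brightest peak, as in analyse_celldata
foci = focus_criterion(np.array(peaks), norm[np.array(peaks)], alpha=0.4)
print(peaks, foci)  # candidate peaks at indices 2 and 7; only the brighter one at index 7 is called a focus
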
requirements.txt ADDED
@@ -0,0 +1,7 @@
pillow
tifffile
matplotlib
numpy
lxml
pandas
scipy
setup.py ADDED
@@ -0,0 +1,29 @@
from setuptools import setup, find_packages

setup(
    name='path_analysis',
    version='0.1.0',
    description='A brief description of your package',
    author='Your Name',
    author_email='youremail@example.com',
    url='https://github.com/yourusername/yourrepository',  # if you have a repo for the project
    packages=find_packages(),  # or specify manually: ['your_package', 'your_package.submodule', ...]
    install_requires=[
        'numpy',  # for example, if your package needs numpy
        'gradio',
        # ... other dependencies
    ],
    classifiers=[
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Developers',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        # ... other classifiers
    ],
    python_requires='>=3.6',  # your project's Python version requirement
    keywords='some keywords related to your project',
    # ... other parameters
)
tests/__init__.py ADDED
File without changes
tests/test_analyse.py ADDED
@@ -0,0 +1,370 @@
import pytest
from path_analysis.analyse import *
from path_analysis.data_preprocess import RemovedPeakData
import numpy as np
import pandas as pd
from math import pi
import xml.etree.ElementTree as ET
from PIL import ImageChops


def test_draw_paths_no_error():
    all_paths = [[[0, 0], [1, 1]], [[2, 2], [3, 3]]]
    foci_stack = np.zeros((5, 5, 5))
    foci_stack[0, 0, 0] = 1.0
    foci_index = [[0], [1]]
    r = 3

    try:
        im = draw_paths(all_paths, foci_stack, foci_index, r)
    except Exception as e:
        pytest.fail(f"draw_paths raised an exception: {e}")


def test_draw_paths_image_size():
    all_paths = [[[0, 0], [1, 1]], [[2, 2], [3, 3]]]
    foci_stack = np.zeros((5, 5, 5))
    foci_stack[0, 0, 0] = 1.0

    foci_index = [[0], [1]]
    r = 3

    im = draw_paths(all_paths, foci_stack, foci_index, r)
    assert im.size == (5, 5), f"Expected image size (5, 5), got {im.size}"


def test_draw_paths_image_modified():
    all_paths = [[[0, 0], [1, 1]], [[2, 2], [3, 3]]]
    foci_stack = np.zeros((5, 5, 5))
    foci_stack[0, 0, 0] = 1.0
    foci_index = [[0], [1]]
    r = 3

    im = draw_paths(all_paths, foci_stack, foci_index, r)
    blank_image = Image.new("RGB", (5, 5), "black")

    # Check if the image is not entirely black (i.e., has been modified)
    diff = ImageChops.difference(im, blank_image)
    assert diff.getbbox() is not None, "The image has not been modified"


def test_calculate_path_length_partials_default_voxel():
    point_list = [(0, 0, 0), (1, 0, 0), (1, 1, 1)]
    expected_result = np.array([0.0, 1.0, 1.0+np.sqrt(2)])
    result = calculate_path_length_partials(point_list)
    np.testing.assert_allclose(result, expected_result, atol=1e-5)


def test_calculate_path_length_partials_custom_voxel():
    point_list = [(0, 0, 0), (1, 0, 0), (1, 1, 0)]
    voxel_size = (1, 2, 1)
    expected_result = np.array([0.0, 1.0, 3.0])
    result = calculate_path_length_partials(point_list, voxel_size=voxel_size)
    np.testing.assert_allclose(result, expected_result, atol=1e-5)


def test_calculate_path_length_partials_single_point():
    point_list = [(0, 0, 0)]
    expected_result = np.array([0.0])
    result = calculate_path_length_partials(point_list)
    np.testing.assert_allclose(result, expected_result, atol=1e-5)


def test_get_paths_from_traces_file():
    # Mock the XML traces file content
    xml_content = '''<?xml version="1.0"?>
    <root>
        <path reallength="5.0">
            <point x="1" y="2" z="3"/>
            <point x="4" y="5" z="6"/>
        </path>
        <path reallength="10.0">
            <point x="7" y="8" z="9"/>
            <point x="10" y="11" z="12"/>
        </path>
    </root>
    '''

    # Create a temporary XML file
    with open("temp_traces.xml", "w") as f:
        f.write(xml_content)

    all_paths, path_lengths = get_paths_from_traces_file("temp_traces.xml")

    expected_paths = [[(1, 2, 3), (4, 5, 6)], [(7, 8, 9), (10, 11, 12)]]
    expected_lengths = [5.0, 10.0]

    assert all_paths == expected_paths, f"Expected paths {expected_paths}, but got {all_paths}"
    assert path_lengths == expected_lengths, f"Expected lengths {expected_lengths}, but got {path_lengths}"

    # Clean up temporary file
    import os
    os.remove("temp_traces.xml")


def test_measure_chrom2():
    # Mock data
    path = [(2, 3, 4), (4, 5, 6), (9, 9, 9)]  # Sample ordered path points
    intensity = np.random.rand(10, 10, 10)  # Random 3D fluorescence data
    config = {
        'z_res': 1,
        'xy_res': 0.5,
        'sphere_radius': 2.5
    }

    # Function call
    _, measurements, measurements_max = measure_chrom2(path, intensity, config)

    # Assertions
    assert len(measurements) == len(path), "Measurements length should match path length"
    assert len(measurements_max) == len(path), "Max measurements length should match path length"
    assert all(0 <= val <= 1 for val in measurements), "All mean measurements should be between 0 and 1 for this mock data"
    assert all(0 <= val <= 1 for val in measurements_max), "All max measurements should be between 0 and 1 for this mock data"


def test_measure_chrom2_z():
    # Mock data
    path = [(2, 3, 4), (4, 5, 6)]  # Sample ordered path points
    _, _, intensity = np.meshgrid(np.arange(10), np.arange(10), np.arange(10))  # 3D fluorescence data - z dependent
    config = {
        'z_res': 1,
        'xy_res': 0.5,
        'sphere_radius': 2.5
    }

    # Function call
    _, measurements, measurements_max = measure_chrom2(path, intensity, config)

    # Assertions
    assert len(measurements) == len(path), "Measurements length should match path length"
    assert len(measurements_max) == len(path), "Max measurements length should match path length"
    assert all(measurements == np.array([4, 6]))
    assert all(measurements_max == np.array([6, 8]))


def test_measure_chrom2_z2():
    # Mock data
    path = [(0, 0, 0), (2, 3, 4), (4, 5, 6)]  # Sample ordered path points
    _, _, intensity = np.meshgrid(np.arange(10), np.arange(10), np.arange(10))  # 3D fluorescence data - z dependent
    config = {
        'z_res': 0.25,
        'xy_res': 0.5,
        'sphere_radius': 2.5
    }

    # Function call
    _, measurements, measurements_max = measure_chrom2(path, intensity, config)

    # Assertions
    assert len(measurements) == len(path), "Measurements length should match path length"
    assert len(measurements_max) == len(path), "Max measurements length should match path length"
    assert all(measurements_max == np.array([9, 9, 9]))


def test_measure_from_mask():
    mask = np.array([
        [0, 1, 0],
        [1, 1, 1],
        [0, 1, 0]
    ])
    measure_stack = np.array([
        [2, 4, 2],
        [4, 8, 4],
        [2, 4, 2]
    ])
    result = measure_from_mask(mask, measure_stack)
    assert result == 24  # Expected sum: 4+4+8+4+4


def test_max_from_mask():
    mask = np.array([
        [0, 1, 0],
        [1, 1, 1],
        [0, 1, 0]
    ])
    measure_stack = np.array([
        [2, 5, 2],
        [4, 8, 3],
        [2, 7, 2]
    ])
    result = max_from_mask(mask, measure_stack)
    assert result == 8  # Expected max: 8


def test_measure_at_point_mean():
    measure_stack = np.array([
        [[2, 2, 2, 0], [4, 4, 6, 0], [3, 3, 2, 0], [0, 0, 0, 0]],
        [[4, 4, 4, 0], [8, 8, 8, 0], [4, 4, 4, 0], [0, 0, 0, 0]],
        [[3, 3, 3, 0], [6, 6, 4, 0], [3, 2, 2, 0], [0, 0, 0, 0]],
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
    ])
    p = (1, 1, 1)
    melem = np.ones((3, 3, 3))
    result = measure_at_point(p, melem, measure_stack, op='mean')
    assert result == 4, "Expected mean: 4"


def test_measure_at_point_mean_off1():
    measure_stack = np.array([
        [[2, 2, 2, 0], [4, 4, 6, 0], [5, 5, 2, 0], [0, 0, 0, 0]],
        [[4, 4, 4, 0], [8, 8, 8, 0], [4, 4, 4, 0], [0, 0, 0, 0]],
        [[3, 3, 3, 0], [6, 6, 4, 0], [3, 2, 2, 0], [0, 0, 0, 0]],
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
    ])
    p = (0, 0, 0)
    melem = np.ones((3, 3, 3))
    result = measure_at_point(p, melem, measure_stack, op='mean')
    assert result == 4.5, "Expected mean: 4.5"


def test_measure_at_point_mean_off2():
    measure_stack = np.array([
        [[2, 2, 2, 0], [4, 4, 6, 0], [5, 5, 2, 0], [0, 0, 0, 0]],
        [[4, 4, 4, 0], [8, 8, 8, 0], [4, 4, 4, 0], [0, 0, 0, 0]],
        [[3, 3, 3, 0], [6, 6, 4, 0], [3, 2, 2, 0], [0, 0, 0, 0]],
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
    ])
    p = (3, 1, 1)
    melem = np.ones((3, 3, 3))
    print(measure_stack[p[0], p[1], p[2]])

    result = measure_at_point(p, melem, measure_stack, op='mean')
    assert result == 32/18  # Expected mean: 32/18


def test_measure_at_point_mean_off3():
    measure_stack = np.array([
        [[2, 2, 2, 0], [4, 4, 6, 0], [5, 5, 2, 0], [0, 0, 0, 0]],
        [[4, 4, 4, 0], [8, 8, 8, 0], [4, 4, 4, 0], [0, 0, 0, 0]],
        [[3, 3, 3, 0], [6, 6, 4, 0], [3, 2, 2, 0], [0, 0, 0, 0]],
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
    ])
    p = (3, 1, 1)
    melem = np.ones((1, 1, 3))
    print(measure_stack[p[0], p[1], p[2]])

    result = measure_at_point(p, melem, measure_stack, op='mean')
    assert result == 0, "Expected mean: 0"


def test_measure_at_point_mean_off4():
    measure_stack = np.array([
        [[2, 2, 2, 0], [4, 4, 6, 0], [5, 5, 2, 0], [0, 0, 0, 0]],
        [[4, 4, 4, 0], [8, 8, 8, 0], [4, 4, 4, 0], [0, 0, 0, 0]],
        [[3, 3, 3, 0], [6, 6, 4, 0], [3, 2, 2, 0], [0, 0, 0, 0]],
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
    ])
    p = (3, 1, 1)
    melem = np.ones((3, 1, 1))
    print(measure_stack[p[0], p[1], p[2]])

    result = measure_at_point(p, melem, measure_stack, op='mean')
    assert result == 3, "Expected mean: 3"


def test_measure_at_point_max():
    measure_stack = np.array([
        [[2, 2, 2], [4, 4, 4], [2, 2, 2]],
        [[4, 5, 4], [8, 7, 9], [4, 4, 4]],
        [[2, 2, 2], [4, 4, 4], [2, 2, 2]]
    ])
    p = (1, 1, 1)
    melem = np.ones((3, 3, 3))
    result = measure_at_point(p, melem, measure_stack, op='max')
    assert result == 9, "Expected max: 9"


def test_make_sphere_equal():
    R = 5
    z_scale_ratio = 1.0

    sphere = make_sphere(R, z_scale_ratio)

    # Check the returned type
    assert isinstance(sphere, np.ndarray), "Output should be a numpy ndarray"

    # Check the shape
    expected_shape = (2*R+1, 2*R+1, 2*R+1)
    assert sphere.shape == expected_shape, f"Expected shape {expected_shape}, but got {sphere.shape}"

    assert (sphere[:, :, ::-1] == sphere).all(), "Expected symmetrical mask"
    assert (sphere[:, ::-1, :] == sphere).all(), "Expected symmetrical mask"
    assert (sphere[::-1, :, :] == sphere).all(), "Expected symmetrical mask"
    assert abs(np.sum(sphere)-4/3*pi*R**3) < 10, "Expected approximate volume to be correct"
    assert (sphere[R, R, 0] == 1), "Expected centre point on top plane to be within sphere"
    assert (sphere[R+1, R, 0] == 0), "Expected point next to centre on top plane to be outside sphere"


def test_extract_peaks_basic():
    cell_id = 1  # Simple per-cell tag
    all_paths = [[[0, 0, 0], [1, 1, 0]]]  # Single, simple path
    path_lengths = [1.41]  # Length of the above path
    measured_traces = [[100, 200]]  # Fluorescence along the path
    config = {'peak_threshold': 0.4, 'sphere_radius': 2, 'xy_res': 1, 'z_res': 1, 'threshold_type': 'per-cell', 'use_corrected_positions': True, 'screening_distance': 10}

    df, foci_absolute_intensity, foci_pos_index, screened_foci_data, trace_thresholds, trace_positions = extract_peaks(cell_id, all_paths, path_lengths, measured_traces, config)

    assert len(df) == 1, "Expected one row in DataFrame"
    assert df['Cell_ID'].iloc[0] == cell_id, "Unexpected cell_id"
    assert list(df['Trace_foci_number']) == [1], "Wrong foci number"
    assert df['Foci_1_position(um)'].iloc[0] == np.sqrt(2)
    assert foci_pos_index == [[1]]
    assert foci_absolute_intensity == [[200]]
    assert screened_foci_data == [[]]
    assert trace_thresholds == [[150+0.4*50]]
    assert np.all(trace_positions[0] == np.array([0, np.sqrt(2)]))


def test_extract_peaks_multiple_paths():
    cell_id = 1
    all_paths = [[[0, 0, 0], [1, 1, 0]], [[1, 1, 200], [2, 2, 200]]]
    path_lengths = [1.41, 1.41]
    measured_traces = [[100, 200], [100, 140]]
    config = {'peak_threshold': 0.4, 'sphere_radius': 2, 'xy_res': 1, 'z_res': 1, 'threshold_type': 'per-trace', 'use_corrected_positions': True, 'screening_distance': 10}

    df, foci_absolute_intensity, foci_pos_index, screened_foci_data, trace_thresholds, trace_positions = extract_peaks(cell_id, all_paths, path_lengths, measured_traces, config)

    assert len(df) == 2, "Expected two rows in DataFrame"
    assert df['Cell_ID'].iloc[0] == cell_id, "Unexpected cell_id"
    assert list(df['Trace_foci_number']) == [1, 1], "Wrong foci number"
    assert df['Foci_1_position(um)'].iloc[0] == np.sqrt(2)
    print(foci_pos_index)
    assert list(map(list, foci_pos_index)) == [[1], [1]]
    assert list(map(list, foci_absolute_intensity)) == [[200], [140]]
    assert trace_thresholds == [150+0.4*50, 120+0.4*20]
    assert np.all(trace_positions[0] == np.array([0, np.sqrt(2)]))
    assert screened_foci_data == [[], []]


def test_extract_peaks_multiple_paths_screened():
    cell_id = 1
    all_paths = [[[0, 0, 0], [1, 1, 0]], [[1, 1, 2], [2, 2, 2]]]
    path_lengths = [1.41, 1.41]
    measured_traces = [[100, 200], [100, 150]]
    config = {'peak_threshold': 0.4, 'sphere_radius': 2, 'xy_res': 1, 'z_res': 1, 'threshold_type': 'per-trace', 'use_corrected_positions': True, 'screening_distance': 10}

    df, foci_absolute_intensity, foci_pos_index, screened_foci_data, trace_thresholds, trace_positions = extract_peaks(cell_id, all_paths, path_lengths, measured_traces, config)

    assert len(df) == 2, "Expected two rows in DataFrame"
    assert df['Cell_ID'].iloc[0] == cell_id, "Unexpected cell_id"
    assert list(df['Trace_foci_number']) == [1, 0], "Wrong foci number"
    assert df['Foci_1_position(um)'].iloc[0] == np.sqrt(2)
    print(foci_pos_index)
    assert list(map(list, foci_pos_index)) == [[1], []]
    assert list(map(list, foci_absolute_intensity)) == [[200], []]
    assert trace_thresholds == [150+0.4*50, None]
    assert np.all(trace_positions[0] == np.array([0, np.sqrt(2)]))
    assert screened_foci_data == [[], [RemovedPeakData(idx=1, screening_peak=(0, 1))]]


def test_extract_peaks_multiple_paths_per_cell():
    cell_id = 1
    all_paths = [[[0, 0, 0], [1, 1, 0]], [[1, 1, 200], [2, 2, 200]]]
    path_lengths = [1.41, 1.41]
    measured_traces = [[100, 200], [100, 140]]
    config = {'peak_threshold': 0.4, 'sphere_radius': 2, 'xy_res': 1, 'z_res': 1, 'threshold_type': 'per-cell', 'use_corrected_positions': True, 'screening_distance': 10}

    df, foci_absolute_intensity, foci_pos_index, screened_foci_data, trace_thresholds, trace_positions = extract_peaks(cell_id, all_paths, path_lengths, measured_traces, config)

    assert len(df) == 2, "Expected two rows in DataFrame"
    assert df['Cell_ID'].iloc[0] == cell_id, "Unexpected cell_id"
    assert list(df['Trace_foci_number']) == [1, 0], "Wrong foci number"
    assert df['Foci_1_position(um)'].iloc[0] == np.sqrt(2)
    assert list(map(list, foci_pos_index)) == [[1], []]
    assert list(map(list, foci_absolute_intensity)) == [[200], []]
    assert trace_thresholds == [150+0.4*50, 120+0.4*50]
    assert np.all(trace_positions[0] == np.array([0, np.sqrt(2)]))
    assert screened_foci_data == [[], []]

@@ -0,0 +1,150 @@
 
+ from path_analysis.data_preprocess import *
+ import numpy as np
+ import pytest
+
+
+ def test_thin_points():
+     # Define a sample point list
+     points = [
+         PeakData([0, 0, 0], 10, 0),
+         PeakData([1, 1, 1], 8, 1),
+         PeakData([10, 10, 10], 12, 2),
+         PeakData([10.5, 10.5, 10.5], 5, 3),
+         PeakData([20, 20, 20], 15, 4)
+     ]
+
+     # Call the thin_peaks function with dmin=5 (for example)
+     removed_indices = thin_peaks(points, dmin=5)
+
+     # Check results
+     # The point at index 1 ([1, 1, 1]) should be removed: it lies within 5 units of the point at index 0 and has lower intensity.
+     # Similarly, the point at index 3 ([10.5, 10.5, 10.5]) should be removed, as it is close to the point at index 2 and has lower intensity.
+     assert set(removed_indices) == {1, 3}
+
+     # Another simple test to check that the function does nothing when points are far apart
+     far_points = [
+         PeakData([0, 0, 0], 10, 0),
+         PeakData([100, 100, 100], 12, 1),
+         PeakData([200, 200, 200], 15, 2)
+     ]
+
+     removed_indices_far = thin_peaks(far_points, dmin=5)
+     assert len(removed_indices_far) == 0  # Expect no points to be removed
+
+
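
The comments in test_thin_points describe the behaviour the assertions rely on: a peak is discarded when a brighter peak lies within dmin of it. The sketch below is a minimal illustration of that rule; it works on plain (position, intensity, index) tuples rather than the real PeakData objects, and it is not the thin_peaks implementation from path_analysis.data_preprocess.

    import numpy as np

    def thin_peaks_sketch(peaks, dmin=5.0):
        """Indices of peaks that have a strictly brighter peak within dmin."""
        removed = []
        for pos_p, int_p, idx_p in peaks:
            for pos_q, int_q, _ in peaks:
                dist = np.linalg.norm(np.asarray(pos_p) - np.asarray(pos_q))
                if 0 < dist < dmin and int_q > int_p:
                    removed.append(idx_p)  # p is dominated by the brighter q
                    break
        return removed

With tuples matching the first points list above, this returns [1, 3], in agreement with the assertion.
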
+ def test_find_peaks2():
+
+     # Basic test
+     data = np.array([0, 0, 0, 0, 0, 0, 5, 0, 3, 0])
+     peaks, _ = find_peaks2(data)
+     assert set(peaks) == {6}  # Expected peak at position 6
+
+     # Basic test
+     data = np.array([0, 2, 0, 0, 0, 0, 0, 0, 0, 0])
+     peaks, _ = find_peaks2(data)
+     assert set(peaks) == {1}  # Expected peak at position 1
+
+
+     # Test with padding impacting peak detection
+     data = np.array([3, 2.9, 0, 0, 0, 3])
+     peaks, _ = find_peaks2(data)
+     assert set(peaks) == {0, 5}  # Peaks at both ends
+
+     # Test with close peaks
+     data = np.array([3, 0, 3])
+     peaks, _ = find_peaks2(data)
+     assert set(peaks) == {2}  # Peak at right end only
+
+
+
+     # Test with close peaks and an explicit minimum distance
+     data = np.array([3, 0, 3])
+     peaks, _ = find_peaks2(data, distance=1)
+     assert set(peaks) == {0, 2}  # Peaks at both ends
+
+     # Test with plateaus of equal maximum values
+     data = np.array([0, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3])
+     peaks, _ = find_peaks2(data, distance=1)
+     assert set(peaks) == {2, 7}  # Peak at centre (rounded to the left) of each group of maximum values
+
+     # Test with prominence threshold
+     data = np.array([0, 1, 0, 0.4, 0])
+     peaks, _ = find_peaks2(data, prominence=0.5)
+     assert peaks == [1]  # Only the peak at position 1 meets the prominence threshold
+
+
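
The comments above mention padding and a distance argument, so find_peaks2 appears to behave like scipy.signal.find_peaks applied to a padded copy of the signal, which lets maxima at the array ends be reported. The sketch below is one plausible reading of that behaviour; the zero padding and the default distance of 5 are assumptions inferred from the assertions, not the actual find_peaks2 code or defaults.

    import numpy as np
    from scipy.signal import find_peaks

    def find_peaks2_sketch(data, distance=5, prominence=None):
        """Detect peaks after padding, so end-of-array maxima count as peaks."""
        padded = np.pad(np.asarray(data, dtype=float), 1)  # one zero at each end
        peaks, properties = find_peaks(padded, distance=distance, prominence=prominence)
        return peaks - 1, properties

Under these assumptions the sketch reproduces, for example, {6} for the first array and {0, 5} for the padded-ends case, but the real defaults in find_peaks2 may differ.
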
+ def test_focus_criterion():
+     pos = np.array([0, 1, 2, 3, 4, 6])
+     values = np.array([0.1, 0.5, 0.2, 0.8, 0.3, 0.9])
+
+     # Basic test
+     assert np.array_equal(focus_criterion(pos, values), np.array([1, 3, 6]))  # values 0.5, 0.8 and 0.9 are at least 0.4 times the max (0.9)
+
+     # Empty test
+     assert np.array_equal(focus_criterion(np.array([]), np.array([])), np.array([]))
+
+     # Test with custom alpha
+     assert np.array_equal(focus_criterion(pos, values, alpha=0.5), np.array([1, 3, 6]))
+
+     # Test with a larger alpha
+     assert np.array_equal(focus_criterion(pos, values, alpha=1.0), [6])  # only the maximum value itself meets the threshold
+
+     # Test with uniformly low values
+     values = np.array([0.1, 0.2, 0.3, 0.4])
+
+     assert np.array_equal(focus_criterion(pos[:4], values), [1, 2, 3])  # values 0.2, 0.3 and 0.4 are at least 0.4 times the max (0.4), i.e. 0.16
+
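
All of the focus_criterion assertions are consistent with a simple relative-intensity cut: keep the positions whose value is at least alpha times the maximum value. The sketch below reproduces that behaviour, including the >= comparison needed for the alpha=1.0 case and the empty-input case; the default alpha of 0.4 is inferred from the first assertion, and none of this is the implementation in path_analysis.data_preprocess.

    import numpy as np

    def focus_criterion_sketch(pos, values, alpha=0.4):
        """Keep positions whose value is at least alpha * max(values)."""
        pos = np.asarray(pos)
        values = np.asarray(values)
        if len(values) == 0:
            return np.array([])  # matches the empty-input assertion
        return pos[values >= alpha * values.max()]

With the pos and values arrays above and the assumed default alpha, this yields [1, 3, 6].
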
+ @pytest.fixture
+ def mock_data():
+     all_paths = [[(0,0,0), (0,2,0), (0,5,0), (0,10,0), (0,15,0), (0,20,0)], [(1,20,0), (1,20,10), (1,20,20)]]  # Mock paths
+     path_lengths = [2.2, 2.3]  # Mock path lengths
+     measured_trace_fluorescence = [[100, 8, 3, 2, 3, 49], [38, 2, 20]]  # Mock fluorescence data
+     return all_paths, path_lengths, measured_trace_fluorescence
+
+ def test_process_cell_traces_return_type(mock_data):
+     all_paths, path_lengths, measured_trace_fluorescence = mock_data
+     result = process_cell_traces(all_paths, path_lengths, measured_trace_fluorescence)
+     assert isinstance(result, CellData), f"Expected CellData but got {type(result)}"
+
+ def test_process_cell_traces_pathdata_list_length(mock_data):
+     all_paths, path_lengths, measured_trace_fluorescence = mock_data
+     result = process_cell_traces(all_paths, path_lengths, measured_trace_fluorescence)
+     assert len(result.pathdata_list) == len(all_paths), f"Expected {len(all_paths)} but got {len(result.pathdata_list)}"
+
+ def test_process_cell_traces_pathdata_path_lengths(mock_data):
+     all_paths, path_lengths, measured_trace_fluorescence = mock_data
+     result = process_cell_traces(all_paths, path_lengths, measured_trace_fluorescence)
+     path_lengths = [p.SC_length for p in result.pathdata_list]
+     expected_path_lengths = [2.2, 2.3]
+     assert path_lengths == expected_path_lengths, f"Expected {expected_path_lengths} but got {path_lengths}"
+
+ def test_process_cell_traces_peaks(mock_data):
+     all_paths, path_lengths, measured_trace_fluorescence = mock_data
+     result = process_cell_traces(all_paths, path_lengths, measured_trace_fluorescence)
+     print(result)
+     peaks = [p.peaks for p in result.pathdata_list]
+     assert peaks == [[0, 5], []]
+
+ # Mock data
+ @pytest.fixture
+ def mock_celldata():
+     pathdata1 = PathData(peaks=[0, 5], points=[(0,0,0), (0,2,0), (0,5,0), (0,10,0), (0,15,0), (0,20,0)], removed_peaks=[], o_intensity=[100, 8, 3, 2, 3, 69], SC_length=2.2)
+     pathdata2 = PathData(peaks=[2], points=[(1,20,0), (1,20,10), (1,20,20)], removed_peaks=[RemovedPeakData(0, (0,5))], o_intensity=[38, 2, 20], SC_length=2.3)
+     return CellData(pathdata_list=[pathdata1, pathdata2])
+
+ def test_analyse_celldata(mock_celldata):
+     data_frame, foci_absolute_intensity, foci_position_index, dominated_foci_data, trace_median_intensity, trace_thresholds = analyse_celldata(mock_celldata, {'peak_threshold': 0.4, 'threshold_type': 'per-trace'})
+     assert len(data_frame) == len(mock_celldata.pathdata_list), "Mismatch in dataframe length"
+     assert len(foci_absolute_intensity) == len(mock_celldata.pathdata_list), "Mismatch in absolute intensities length"
+     assert len(foci_position_index) == len(mock_celldata.pathdata_list), "Mismatch in position indices length"
+
+     assert list(map(list, foci_position_index)) == [[0, 5], [2]]
+
+
+ def test_analyse_celldata_per_cell(mock_celldata):
+     data_frame, foci_absolute_intensity, foci_position_index, dominated_foci_data, trace_median_intensity, trace_thresholds = analyse_celldata(mock_celldata, {'peak_threshold': 0.4, 'threshold_type': 'per-cell'})
+     assert len(data_frame) == len(mock_celldata.pathdata_list), "Mismatch in dataframe length"
+     assert len(foci_absolute_intensity) == len(mock_celldata.pathdata_list), "Mismatch in absolute intensities length"
+     assert len(foci_position_index) == len(mock_celldata.pathdata_list), "Mismatch in position indices length"
+     assert list(map(list, foci_position_index)) == [[0, 5], []]
+
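
The only difference between the last two tests is the threshold_type entry of the config passed to analyse_celldata: under 'per-trace' the dim focus at index 2 on the weaker second trace is retained, while under 'per-cell' it falls below the cell-wide threshold and the expected positions become [[0, 5], []]. A minimal way to express the two configurations used above (illustrative only):

    base_config = {'peak_threshold': 0.4}
    per_trace_config = {**base_config, 'threshold_type': 'per-trace'}  # expects [[0, 5], [2]]
    per_cell_config = {**base_config, 'threshold_type': 'per-cell'}    # expects [[0, 5], []]
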
tests/test_results.py ADDED
@@ -0,0 +1,150 @@
+
+ import pytest
+ from path_analysis.analyse import *
+ from path_analysis.data_preprocess import RemovedPeakData
+ import numpy as np
+ from math import pi
+ import xml.etree.ElementTree as ET
+ from PIL import ImageChops
+
+ from pathlib import Path
+
+ import matplotlib
+ matplotlib.use('Agg')
+
+ @pytest.fixture(scope="module")
+ def script_loc(request):
+     '''Return the directory of the currently running test script'''
+
+     return Path(request.fspath).parent
+
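
Path(request.fspath) wraps pytest's legacy py.path-based attribute, which works but is considered legacy; on pytest 7 and later the fixture could equivalently use the pathlib-based request.path. A possible alternative, not part of this commit:

    import pytest

    @pytest.fixture(scope="module")
    def script_loc(request):
        '''Return the directory of the currently running test script'''
        # request.path is a pathlib.Path on pytest >= 7; request.fspath also still works.
        return request.path.parent
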
+ def test_image_1(script_loc):
+
+     config = {'sphere_radius': 0.1984125,
+               'peak_threshold': 0.4,
+               'xy_res': 0.0396825,
+               'z_res': 0.0909184,
+               'threshold_type': 'per-cell',
+               'use_corrected_positions': True,
+               'screening_distance': 10,
+               }
+
+     data_loc = script_loc.parent.parent / 'test_data' / 'hei10 ++ 15.11.19 p22s2 image 9'
+
+
+     image_input = data_loc / 'HEI10.tif'
+     path_input = data_loc / 'SNT_Data.traces'
+
+     paths, traces, fig, extracted_peaks = analyse_paths('Cell', image_input, path_input, config)
+
+     assert np.allclose(extracted_peaks['SNT_trace_length(um)'], [61.47, 70.40, 51.93, 43.94, 62.24], atol=1e-2)
+     assert np.allclose(extracted_peaks['SNT_trace_length(um)'], extracted_peaks['Measured_trace_length(um)'], atol=1e-8)
+     assert list(extracted_peaks['Trace_foci_number']) == [2, 3, 2, 2, 3]
+
+ def test_image_2(script_loc):
+
+     config = {'sphere_radius': 0.1984125,
+               'peak_threshold': 0.4,
+               'xy_res': 0.0396825,
+               'z_res': 0.0909184,
+               'threshold_type': 'per-cell',
+               'use_corrected_positions': True,
+               'screening_distance': 10,
+               }
+
+     data_loc = script_loc.parent.parent / 'test_data' / 'z-optimised'
+
+
+     image_input = data_loc / 'HEI10.tif'
+     path_input = data_loc / 'ZYP1.traces'
+
+     paths, traces, fig, extracted_peaks = analyse_paths('Cell', image_input, path_input, config)
+
+     assert np.allclose(extracted_peaks['SNT_trace_length(um)'], extracted_peaks['Measured_trace_length(um)'], atol=1e-8)
+     assert list(extracted_peaks['Trace_foci_number']) == [2, 2, 1, 2, 1]
+
+ def test_image_3(script_loc):
+
+     config = {'sphere_radius': 0.1984125,
+               'peak_threshold': 0.4,
+               'xy_res': 0.0396825,
+               'z_res': 0.1095510,
+               'threshold_type': 'per-trace',
+               'use_corrected_positions': True,
+               'screening_distance': 10,
+
+               }
+
+     data_loc = script_loc.parent.parent / 'test_data' / 'arenosa SN A1243 image 18-20230726T142725Z-001' / 'arenosa SN A1243 image 18'
+
+
+     image_input = data_loc / 'HEI10.tif'
+     path_input = data_loc / 'SNT_Data.traces'
+
+     paths, traces, fig, extracted_peaks = analyse_paths('Cell', image_input, path_input, config)
+
+     assert np.allclose(extracted_peaks['SNT_trace_length(um)'], extracted_peaks['Measured_trace_length(um)'], atol=1e-8)
+     assert list(extracted_peaks['Trace_foci_number']) == [2, 1, 1, 1, 2, 1, 1, 1]
+
+ def test_image_4(script_loc):
+
+     config = {'sphere_radius': 10.,
+               'peak_threshold': 0.4,
+               'xy_res': 1,
+               'z_res': 1,
+               'threshold_type': 'per-trace',
+               'use_corrected_positions': True,
+               'screening_distance': 10,
+
+               }
+
+     data_loc = script_loc.parent.parent / 'test_data' / 'mammalian 2D-20230821T180708Z-001' / 'mammalian 2D' / '1'
+
+
+     image_input = data_loc / 'C2-Pachytene SIM-1.tif'
+     path_input = data_loc / 'SNT_Data.traces'
+
+     paths, traces, fig, extracted_peaks = analyse_paths('Cell', image_input, path_input, config)
+
+     assert np.allclose(extracted_peaks['SNT_trace_length(um)'], extracted_peaks['Measured_trace_length(um)'], atol=1e-8)
+
+     valid_results = [{1}, {1}, {2, 3}, {1, 2}, {1, 2}, {1}, {1}, {2}, {1}, {1}, {1, 2}, {1}, {1, 2}, {1, 2}, {1}, {1}, {1}, {1}, {1}]
+     measured = extracted_peaks['Trace_foci_number']
+
+     print(measured)
+     assert len(measured) == len(valid_results)
+     assert all(m in v for m, v in zip(measured, valid_results))
+
+
+
+ def test_image_5(script_loc):
+
+     config = {'sphere_radius': 0.3,
+               'peak_threshold': 0.4,
+               'xy_res': 0.1023810,
+               'z_res': 1,
+               'threshold_type': 'per-trace',
+               'use_corrected_positions': True,
+               'screening_distance': 10,
+
+               }
+
+     data_loc = script_loc.parent.parent / 'test_data' / 'mammalian 2D-20230821T180708Z-001' / 'mammalian 2D' / '2'
+
+
+     image_input = data_loc / 'C1-CNTD1FHFH CSHA 1in5000 22612 Slide 6-102-1.tif'
+     path_input = data_loc / 'SNT_Data.traces'
+
+     paths, traces, fig, extracted_peaks = analyse_paths('Cell', image_input, path_input, config)
+
+     assert np.allclose(extracted_peaks['SNT_trace_length(um)'], extracted_peaks['Measured_trace_length(um)'], atol=1e-8)
+
+     valid_results = [1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 1]
+     measured = extracted_peaks['Trace_foci_number']
+
+     assert list(measured) == valid_results
+
+
+
+
+
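
These five tests are integration tests against image data that is not part of the repository: they resolve a test_data directory two levels above the tests folder. If that data is absent they fail at the file-loading stage; a module-level guard along the following lines (illustrative, not part of this commit) would skip them instead:

    import pytest
    from pathlib import Path

    # Same location the tests compute via script_loc.parent.parent / 'test_data'.
    _DATA_ROOT = Path(__file__).parent.parent.parent / 'test_data'
    pytestmark = pytest.mark.skipif(not _DATA_ROOT.exists(),
                                    reason="test_data directory not available")
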