ReyaLabColumbia commited on
Commit
c16c74c
·
verified ·
1 Parent(s): 0b31c6b

Delete Colony_Analyzer_AI_zstack2.py

Browse files
Files changed (1) hide show
  1. Colony_Analyzer_AI_zstack2.py +0 -385
Colony_Analyzer_AI_zstack2.py DELETED
@@ -1,385 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- Created on Thu Mar 20 14:23:27 2025
5
-
6
- @author: mattc
7
- """
8
-
9
- import os
10
- import cv2
11
- #this is the huggingface version
12
- # path = '/home/mattc/Documents/ColonyAssaySegformer/'
13
- # file_list = os.listdir(path)
14
- # file_list = [x for x in file_list if (x[-4::]==".tif" or x[-5::]==".tiff")]
15
- def cut_img(path, x):
16
- img_map = {}
17
- img = cv2.imread(path + x)
18
- name = x.split(".")[0]
19
- i_num = img.shape[0]/512
20
- j_num = img.shape[1]/512
21
- count = 1
22
- for i in range(int(i_num)):
23
- for j in range(int(j_num)):
24
- img2 = img[(512*i):(512*(i+1)), (512*j):(512*(j+1))]
25
- cv2.imwrite(path+name+'_part'+str(count)+'.tif', img2)
26
- img_map[count] = path+name+'_part'+str(count)+'.tif'
27
- count +=1
28
- return(img_map)
29
-
30
- import numpy as np
31
-
32
- def stitch(img_map):
33
- for x in img_map:
34
- temp = img_map[x]
35
- img_map[x] = cv2.imread(temp)
36
- if (img_map[x] is None):
37
- img_map[x] = cv2.imread(temp, cv2.IMREAD_UNCHANGED)
38
- os.remove(temp)
39
- rows = [
40
- np.hstack([img_map[1], img_map[2], img_map[3], img_map[4]]), # First row (images 0 to 3)
41
- np.hstack([img_map[5], img_map[6], img_map[7], img_map[8]]), # Second row (images 4 to 7)
42
- np.hstack([img_map[9], img_map[10], img_map[11], img_map[12]]) # Third row (images 8 to 11)
43
- ]
44
-
45
- # Stack rows vertically
46
- return(np.vstack(rows))
47
-
48
- #img_map = cut_img(path, file_list[0])
49
-
50
-
51
- from PIL import Image
52
-
53
-
54
-
55
- import matplotlib.pyplot as plt
56
-
57
- def visualize_segmentation(mask, image=0):
58
- plt.figure(figsize=(10, 5))
59
-
60
- if(not np.isscalar(image)):
61
- # Show original image if it is entered
62
- plt.subplot(1, 2, 1)
63
- plt.imshow(image)
64
- plt.title("Original Image")
65
- plt.axis("off")
66
-
67
- # Show segmentation mask
68
- plt.subplot(1, 2, 2)
69
- plt.imshow(mask, cmap="gray") # Show as grayscale
70
- plt.title("Segmentation Mask")
71
- plt.axis("off")
72
-
73
- plt.show()
74
-
75
- import torch
76
- from transformers import SegformerForSemanticSegmentation
77
- # Load fine-tuned model
78
- model = SegformerForSemanticSegmentation.from_pretrained("ReyaLabColumbia/Segformer_Colony_Counter") # Adjust path
79
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
80
- model.to(device)
81
- model.eval() # Set to evaluation mode
82
-
83
- # Load image processor
84
- from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
85
- image_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b3-finetuned-cityscapes-1024-1024")
86
-
87
- def preprocess_image(image_path):
88
- image = Image.open(image_path).convert("RGB") # Open and convert to RGB
89
- inputs = image_processor(image, return_tensors="pt") # Preprocess for model
90
- return image, inputs["pixel_values"]
91
-
92
- def postprocess_mask(logits):
93
- mask = torch.argmax(logits, dim=1) # Take argmax across the class dimension
94
- return mask.squeeze().cpu().numpy() # Convert to NumPy array
95
-
96
-
97
- def eval_img(image_path):
98
- # Load and preprocess image
99
- image, pixel_values = preprocess_image(image_path)
100
- pixel_values = pixel_values.to(device)
101
- with torch.no_grad(): # No gradient calculation for inference
102
- outputs = model(pixel_values=pixel_values) # Run model
103
- logits = outputs.logits
104
- # Convert logits to segmentation mask
105
- segmentation_mask = postprocess_mask(logits)
106
- #visualize_segmentation(segmentation_mask,image)
107
- segmentation_mask = cv2.resize(segmentation_mask, (512, 512), interpolation=cv2.INTER_LINEAR_EXACT)
108
- return(segmentation_mask)
109
-
110
-
111
- # for x in img_map:
112
- # mask = eval_img(img_map[x])
113
- # cv2.imwrite(img_map[x], mask)
114
- # del mask,x
115
- # p = stitch(img_map)
116
- # visualize_segmentation(p)
117
-
118
- # num_colony = np.count_nonzero(p == 1) # Counts number of 1s
119
- # num_necrosis = np.count_nonzero(p == 2)
120
-
121
- # num_necrosis/num_colony
122
-
123
- def find_colonies(mask, size_cutoff, circ_cutoff):
124
- binary_mask = np.where(mask == 1, 255, 0).astype(np.uint8)
125
- contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
126
- contoursf = []
127
- areas = []
128
- for x in contours:
129
- area = cv2.contourArea(x)
130
- if (area < size_cutoff):
131
- continue
132
- perimeter = cv2.arcLength(x, True)
133
-
134
- # Avoid division by zero
135
- if perimeter == 0:
136
- continue
137
-
138
- # Calculate circularity
139
- circularity = (4 * np.pi * area) / (perimeter ** 2)
140
- if circularity >= circ_cutoff:
141
- contoursf.append(x)
142
- areas.append(area)
143
- return(contoursf, areas)
144
-
145
- def find_necrosis(mask):
146
- binary_mask = np.where(mask == 2, 255, 0).astype(np.uint8)
147
- contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
148
- return(contours)
149
-
150
- # contour_image = np.zeros_like(p)
151
- # contours = find_necrosis(p)
152
- # cv2.drawContours(contour_image, contours, -1, (255), 2)
153
- # visualize_segmentation(contour_image)
154
- import pandas as pd
155
- def compute_centroid(contour):
156
- M = cv2.moments(contour)
157
- if M["m00"] == 0: # Avoid division by zero
158
- return None
159
- cx = int(M["m10"] / M["m00"])
160
- cy = int(M["m01"] / M["m00"])
161
- return (cx, cy)
162
-
163
-
164
- def contours_overlap_using_mask(contour1, contour2, image_shape=(1536, 2048)):
165
- """Check if two contours overlap using a bitwise AND mask."""
166
- import numpy as np
167
- import cv2
168
- mask1 = np.zeros(image_shape, dtype=np.uint8)
169
- mask2 = np.zeros(image_shape, dtype=np.uint8)
170
-
171
-
172
- # Draw each contour as a white shape on its respective mask
173
- cv2.drawContours(mask1, [contour1], -1, 255, thickness=cv2.FILLED)
174
- cv2.drawContours(mask2, [contour2], -1, 255, thickness=cv2.FILLED)
175
-
176
-
177
- # Compute bitwise AND to find overlapping regions
178
- overlap = cv2.bitwise_and(mask1, mask2)
179
-
180
- return np.any(overlap)
181
-
182
- def analyze_colonies(mask, size_cutoff, circ_cutoff):
183
- colonies,areas = find_colonies(mask, size_cutoff, circ_cutoff)
184
- necrosis = find_necrosis(mask)
185
-
186
- data = []
187
-
188
- for x in range(len(colonies)):
189
- colony = colonies[x]
190
- colony_area = areas[x]
191
- centroid = compute_centroid(colony)
192
-
193
- # Check if any necrosis contour is inside the colony
194
- necrosis_area = 0
195
- nec_list =[]
196
- for nec in necrosis:
197
- # Check if the first point of the necrosis contour is inside the colony
198
- if contours_overlap_using_mask(colony, nec):
199
- nec_area = cv2.contourArea(nec)
200
- necrosis_area += nec_area
201
- nec_list.append(nec)
202
-
203
- data.append({
204
- "colony_area": colony_area,
205
- "necrosis_area": necrosis_area,
206
- "centroid": centroid,
207
- "percent_necrosis": necrosis_area/colony_area,
208
- "contour": colony,
209
- "nec_contours": nec_list
210
- })
211
-
212
- # Convert results to a DataFrame
213
- df = pd.DataFrame(data)
214
- df.index = range(1,len(df.index)+1)
215
- return(df)
216
-
217
-
218
- def contour_overlap(contour1, contour2, centroid1, centroid2, area1, area2, centroid_thresh=30, area_thresh = .4, img_shape = (1536, 2048)):
219
- """
220
- Determines the overlap between two contours.
221
- Returns:
222
- 0: No overlap
223
- 1: Overlap but does not meet strict conditions
224
- 2: Overlap >= 80% of the larger contour and centroids are close
225
- """
226
- # Create blank images
227
- img1 = np.zeros(img_shape, dtype=np.uint8)
228
- img2 = np.zeros(img_shape, dtype=np.uint8)
229
-
230
- # Draw filled contours
231
- cv2.drawContours(img1, [contour1], -1, 255, thickness=cv2.FILLED)
232
- cv2.drawContours(img2, [contour2], -1, 255, thickness=cv2.FILLED)
233
-
234
- # Compute overlap
235
- intersection = cv2.bitwise_and(img1, img2)
236
- intersection_area = np.count_nonzero(intersection)
237
-
238
- if intersection_area == 0:
239
- return 0 # No overlap
240
-
241
- # Compute centroid distance
242
- centroid_distance = float(np.sqrt(abs(centroid1[0]-centroid2[0])**2 + abs(centroid1[1]-centroid2[1])**2))
243
- # Check percentage overlap relative to the larger contour
244
- overlap_ratio = intersection_area/max(area1, area2)
245
- if overlap_ratio >= area_thresh and centroid_distance <= centroid_thresh:
246
- if area1 > area2:
247
- return(2)
248
- else:
249
- return(3)
250
- else:
251
- return 1 # Some overlap but not meeting strict criteria
252
-
253
- def compare_frames(frame1, frame2):
254
- for i in range(1, len(frame1)+1):
255
- if frame1.loc[i,"exclude"] == True:
256
- continue
257
- for j in range(1, len(frame2)+1):
258
- if frame2.loc[j,"exclude"] == True:
259
- continue
260
- temp = contour_overlap(frame1.loc[i, "contour"], frame2.loc[j, "contour"], frame1.loc[i, "centroid"], frame2.loc[j, "centroid"], frame1.loc[i, "colony_area"], frame2.loc[j, "colony_area"])
261
- if temp ==2:
262
- frame2.loc[j,"exclude"] = True
263
- elif temp ==3:
264
- frame1.loc[i, "exclude"] = True
265
- break
266
- frame1 = frame1[frame1["exclude"]==False]
267
- frame2 = frame2[frame2["exclude"]==False]
268
- df = pd.concat([frame1, frame2], axis=0)
269
- df.index = range(1,len(df.index)+1)
270
- return(df)
271
-
272
- def main(args):
273
- path = args[0]
274
- files = args[1]
275
- min_size = args[2]
276
- min_circ = args[3]
277
- colonies = {}
278
- for x in files:
279
- img_map = cut_img(path, x)
280
- for z in img_map:
281
- mask = eval_img(img_map[z])
282
- cv2.imwrite(img_map[z], mask)
283
- del mask,z
284
- p = stitch(img_map)
285
- frame = analyze_colonies(p, min_size, min_circ)
286
- frame["source"] = x
287
- frame["exclude"] = False
288
- if isinstance(colonies, dict):
289
- colonies = frame
290
- else:
291
- colonies = compare_frames(frame, colonies)
292
- counts = {}
293
- for x in files:
294
- counts[x] = list(colonies["source"]).count(x)
295
- best = [x, counts[x]]
296
- del x
297
- for x in counts:
298
- if counts[x] > best[1]:
299
- best[0] = x
300
- best[1] = counts[x]
301
- del x, counts
302
- best = best[0]
303
- img = cv2.imread(path + best)
304
- for x in files:
305
- if x == best:
306
- continue
307
- mask = np.zeros_like(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
308
- contours = colonies[colonies["source"]==x]
309
- contours = list(contours["contour"])
310
- cv2.drawContours(mask, contours, -1, 255, thickness=cv2.FILLED)
311
- # Extract all ROIs from the source image at once
312
- src_image = cv2.imread(path +x)
313
- roi = cv2.bitwise_and(src_image, src_image, mask=mask)
314
- # Paste the extracted regions onto the destination image
315
- np.copyto(img, roi, where=(mask[..., None] == 255))
316
- try:
317
- del x, mask, src_image, roi, best, contours
318
- except:
319
- pass
320
-
321
- img = cv2.copyMakeBorder(img,top=0, bottom=10,left=0,right=10, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
322
- colonies = colonies.sort_values(by=["colony_area"], ascending=False)
323
- colonies = colonies[colonies["colony_area"]>= min_size]
324
- colonies.index = range(1,len(colonies.index)+1)
325
- #nearby is a boolean list of whether a colony has overlapping colonies. If so, labelling positions change
326
- nearby = [False]*len(colonies)
327
- areas = list(colonies["colony_area"])
328
- for i in range(len(colonies)):
329
- cv2.drawContours(img, [list(colonies["contour"])[i]], -1, (0, 255, 0), 2)
330
- cv2.drawContours(img, list(colonies['nec_contours'])[i], -1, (0, 0, 255), 2)
331
- coords = list(list(colonies["centroid"])[i])
332
- if coords[0] > 1950:
333
- #if a colony is too close to the right edge, makes the label move to left
334
- coords[0] = 1950
335
- for j in range(len(colonies)):
336
- if j == i:
337
- continue
338
- coords2 = list(list(colonies["centroid"])[j])
339
- if ((abs(coords[0] - coords2[0]) + abs(coords[1] - coords2[1])) <= 40):
340
- nearby[i] = True
341
- break
342
- if nearby[i] ==True:
343
- #If the colony has nearby colonies, this adjusts the labels so they are smaller and are positioned based on the approximate radius of the colony
344
- # a random number is generated, and based on that, the label is put at the top or bottom, left or right
345
- radius= int(np.sqrt(areas[i]/3.1415)*.9)
346
- n = np.random.random()
347
- if n >.75:
348
- new_x = min(coords[0] + radius, 2000)
349
- new_y = min(coords[1] + radius, 1480)
350
- elif n >.5:
351
- new_x = min(coords[0] + radius, 2000)
352
- new_y = max(coords[1] - radius, 50)
353
- elif n >.25:
354
- new_x = max(coords[0] - radius, 0)
355
- new_y = min(coords[1] + radius, 1480)
356
- else:
357
- new_x = max(coords[0] - radius, 0)
358
- new_y = max(coords[1] - radius, 50)
359
- cv2.putText(img, str(colonies.index[i]), (new_x,new_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
360
- del n, radius, new_x, new_y
361
- else:
362
- cv2.putText(img, str(colonies.index[i]), coords, cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 2)
363
- del nearby, areas
364
- colonies = colonies.drop('contour', axis=1)
365
- colonies = colonies.drop('nec_contours', axis=1)
366
- colonies = colonies.drop('exclude', axis=1)
367
- img = cv2.copyMakeBorder(img,top=10, bottom=0,left=10,right=0, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
368
-
369
- colonies.insert(loc=0, column="Colony Number", value=[str(x) for x in range(1, len(colonies)+1)])
370
- total_area_dark = sum(colonies['necrosis_area'])
371
- total_area_light = sum(colonies['colony_area'])
372
- ratio = total_area_dark/(abs(total_area_light)+1)
373
-
374
- colonies.loc[len(colonies)+1] = ["Total", total_area_light, total_area_dark, None, ratio, None]
375
- Parameters = pd.DataFrame({"Minimum colony size in pixels":[min_size], "Minimum colony circularity":[min_circ]})
376
- with pd.ExcelWriter(path+"Group_analysis_results.xlsx") as writer:
377
- colonies.to_excel(writer, sheet_name="Colony data", index=False)
378
- Parameters.to_excel(writer, sheet_name="Parameters", index=False)
379
- caption = np.ones((150, 2068, 3), dtype=np.uint8) * 255 # Multiply by 255 to make it white
380
- cv2.putText(caption, "Total area necrotic: "+str(total_area_dark)+ ", Total area living: "+str(total_area_light)+", Ratio: "+str(ratio), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
381
-
382
-
383
-
384
- cv2.imwrite(path+'Group_analysis_results.png', np.vstack((img, caption)))
385
- return(np.vstack((img, caption)))