import gradio as gr
import spaces

import tifffile as tiff
import zarr
import numpy as np
import os
import cv2
from PIL import Image 
from skimage.feature import peak_local_max
import scipy as sc

import huggingface_hub

# Available backend options are: "jax", "torch", "tensorflow".
os.environ["KERAS_BACKEND"] = "torch"
	
import keras
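
# Compatibility shim: the saved .keras model files reference several Keras-internal
# classes by name, so the block below registers those classes with Keras' export
# registry under the names the files expect. On Keras versions without
# keras.src.api_export this is unnecessary and the ModuleNotFoundError is ignored.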

try:
    from keras.src import api_export
    api_export.REGISTERED_NAMES_TO_OBJS["keras.models.functional.Functional"] = keras.src.models.functional.Functional
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.numpy.Concatenate"] = keras.src.ops.numpy.Concatenate
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.numpy.Flip"] = keras.src.ops.numpy.Flip
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.numpy.GetItem"] = keras.src.ops.numpy.GetItem
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.numpy.Stack"] = keras.src.ops.numpy.Stack
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.numpy.Absolute"] = keras.src.ops.numpy.Absolute
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.nn.Conv"] = keras.src.ops.nn.Conv
    api_export.REGISTERED_NAMES_TO_OBJS["keras.backend.torch.optimizers.torch_adam.Adam"] = keras.src.optimizers.Adam
except ModuleNotFoundError:
    pass  # keras.src.api_export does not exist in this Keras version; no registration needed


import keras.saving

model_addresses = {"Intestinal organoids (0.32x0.32x2.0 um)": 'sjtans/organoids_pytorch',
                   "C. elegans embryo (0.1x0.1x1.0 um)": 'sjtans/elegans_pytorch'}



# placeholder frames shown in the AnnotatedImage components before a result exists
fp0 = np.zeros((96, 128), dtype=np.uint8)

fp1 = np.ones((96, 128), dtype=np.uint8) * 200

from huggingface_hub import hf_hub_download

def download_model(model):
    return hf_hub_download(repo_id=model, filename="model.keras")
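
# A minimal usage sketch: download_model('sjtans/organoids_pytorch') fetches that
# repo's model.keras into the local Hugging Face cache and returns its path.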



# generic image reader
def imread(filepath):
    print('imread')
    fpath, fext = os.path.splitext(filepath)
    
    if fext in ['.tiff', '.tif']:
        print('imread_tiff')
        img = tiff.imread(filepath)
    else:
        print('imread_cv2')
        img = cv2.imread(filepath)

    return img

def check_dims(filepath):

    tif = tiff.TiffFile(filepath)

    store = tif.aszarr()
    img = zarr.open(store, mode='r', chunks=None)
    store.close()

    if img.ndim==3:
        return img.shape[0], None
    elif img.ndim==4:
        return img.shape[0], img.shape[1]
    elif img.ndim==5:
        return img.shape[1], img.shape[2]
    else:
        raise ValueError("TIF has wrong dimensions")
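
# Axis conventions assumed throughout: 3D tifs are (Z, Y, X), 4D tifs are
# (Z, C, Y, X) and 5D tifs are (T, Z, C, Y, X), so check_dims returns the number
# of z-planes plus the number of channels (or None when there is no channel axis).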
    

# tiff volume to png slice
def tif_view(filepath, z, c=0, show_depth=True):
    fpath, fext = os.path.splitext(filepath)
    print('tif'+filepath)
    print('tif'+ fext)
    if fext in ['.tiff', '.tif']:
        img = get_slice(filepath, z, c = c)

        # get slice above and below
        if show_depth:
            img = np.stack([img, get_slice(filepath, z-1, c = c), get_slice(filepath, z+1, c = c)],axis=-1)
        else:
            img = np.stack([img]*3,axis=-1)
        
        Ly, Lx, nchan = img.shape
        imgi = np.zeros((Ly, Lx, 3))
        nn = np.minimum(3, img.shape[-1])
        imgi[:,:,:nn] = img[:,:,:nn]

        imgi = imgi - np.min(imgi)
        imgi = imgi/(np.max(imgi)+0.0000001)
        imgi = (255. * imgi)
        
        filepath = fpath+'z'+str(z)+'c'+str(c)+'.png'
        tiff.imwrite(filepath, imgi.astype('uint8'))
    else: 
        raise ValueError("not a TIF/TIFF")
        
    print('tif'+filepath)   
    return filepath

def get_slice(filepath, z, c=0):

    tif = tiff.TiffFile(filepath)
    
    store = tif.aszarr()
    img = zarr.open(store, mode='r', chunks=None) 
    store.close()

    print(z)
    
    if img.ndim==3:
        if (z>=img.shape[0]) | (z<0):
            print('z too big')
            return np.zeros((img.shape[1], img.shape[2]))
    if img.ndim==4:
        if (z>=img.shape[0]) | (z<0):
            return np.zeros((img.shape[2], img.shape[3]))
    if img.ndim==5:
        if (z>=img.shape[1]) | (z<0):
            return np.zeros((img.shape[3], img.shape[4]))

            
    if img.ndim==2:
        raise ValueError("TIF has only two dimensions")
    if img.ndim==3:
        img = img[z,:,:]
    if img.ndim==4:
        img = img[z,c,:,:]
        
    # select first timepoint
    if img.ndim==5:
        img = img[0,z,c,:,:]
        print(img.shape)
    if img.ndim>5:
        raise ValueError("TIF cannot have more than five dimensions")
    return img    

def get_volume(filepath, c=0):

    img = tiff.imread(filepath)
    print(img.shape)
    if img.ndim==2:
        raise ValueError("TIF has only two dimensions")
    if img.ndim==4:
        img = img[:,c,:,:]
    # select first timepoint
    if img.ndim==5:
        img = img[0,:,c,:,:]
        print(img.shape)
    if img.ndim>5:
        raise ValueError("TIF cannot have more than five dimensions")
    return img    

def tif_view_3D(filepath, z):
    fpath, fext = os.path.splitext(filepath)
    print('tif'+filepath)
    print('tif'+ fext)

    # assumes (t,)z,(c,)y,x for now
    if fext in ['.tiff', '.tif']:
        img = tiff.imread(filepath)
        print(img.shape)
        if img.ndim==2:
            raise ValueError("TIF has only two dimensions")

        # select first timepoint
        if img.ndim==5:
            img = img[0,:,:,:,:]
            print(img.shape)

        #distinguishes between z,y,x and z,c,y,x
        if img.ndim==4:
            img = img[z,:,:,:]
            print(img.shape)
        elif img.ndim==3:
            img = img[z,:,:]
            print(img.shape)
            img = np.tile(img[:,:,np.newaxis], [1,1,3])
        else:
            raise ValueError("TIF cannot have more than five dimensions")

        imin = np.argmin(img.shape)
        img = np.moveaxis(img, imin, 2)
        print(img.shape)

        Ly, Lx, nchan = img.shape
        imgi = np.zeros((Ly, Lx, 3))
        nn = np.minimum(3, img.shape[-1])
        imgi[:,:,:nn] = img[:,:,:nn]

        imgi = imgi/(np.max(imgi)+0.0000001)
        imgi = (255. * imgi)
        
        filepath = fpath+'.png'
        tiff.imwrite(filepath, imgi.astype('uint8'))
    else: 
        raise ValueError("not a TIF/TIFF")
        
    print('tif'+filepath)   
    return filepath

# function to change image appearance
def norm_path(filepath):
    img = imread(filepath)
    img = img/(np.max(img)+0.0000001)
    #img = np.clip(img, 0, 1)
    fpath, fext = os.path.splitext(filepath)
    filepath = fpath +'.png'
    pil_image = Image.fromarray((255. * img).astype(np.uint8))
    pil_image.save(filepath)
    #imsave(filepath, pil_image)
    print('norm'+filepath)
    return filepath 

def update_image(filepath, z): 
    print('update_img')   
    #for f in filepath:
        #f = tif_view(f, z)
    filepath_show = tif_view(filepath[-1], z)
    filepath_show = norm_path(filepath_show)
    print(filepath_show)
    print(filepath)

    max_z, num_c = check_dims(filepath[-1])

    z = min(z, max_z)

    if num_c is None:
        visible_c = False
        num_c = 10
    else:
        visible_c = True

    print(visible_c)
    
    return (filepath_show,  [((5, 5, 10, 10), 'nothing')]), filepath, (fp0, [((5, 5, 10, 10), 'nothing')]), None, None, gr.Slider(0, max_z-1, value = z, visible=True),  gr.Slider(0, num_c-1, visible=visible_c)
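
# update_image's return tuple maps, in order, to: the input AnnotatedImage, the
# upload state, the output AnnotatedImage (reset to a placeholder), the two
# download buttons (cleared), and the z-depth / channel sliders with their ranges
# and visibility updated.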

def update_with_example(filepath):
    print('update_btn')
    print(filepath)
    fpath, fext = os.path.splitext(filepath)

    filepath = fpath+ '.tif' 
    filepath = filepath.split('/')[-1]
    filepath = "./gradio_examples/"+filepath
    
    return update_image([filepath], z=10)

def example(filepath):
    print('update_btn')
    print(filepath)
    filepath_show = filepath
    fpath, fext = os.path.splitext(filepath)

    filepath = fpath+ '.tif' 
    filepath = filepath.split('/')[-1]
    filepath = "./gradio_examples/"+filepath
    print(filepath)
    return(filepath_show)

def update_button(filepath, z):
    print('update_btn')
    print(filepath)
    filepath_show = tif_view(filepath, z)
    filepath_show = norm_path(filepath_show)
    print(filepath_show)
    return (filepath_show, [((5, 5, 10, 10), 'nothing')]), [filepath], (fp0, [((5, 5, 10, 10), 'nothing')], z)

def update(filepath, filepath_result, filepath_coordinates, z, c): 
    print('update_img')   

    filepath_show = tif_view(filepath[-1], z, c=c)
    filepath_show = norm_path(filepath_show)

    if isinstance(filepath_result, str):
        filepath_result_show = tif_view(filepath_result, z, show_depth=False)
        filepath_result_show = norm_path(filepath_result_show)
    else:
        filepath_result_show = fp0
    print(filepath_show)
    print(filepath)

    if filepath_coordinates is None:
        display_boxes = []
    else:        
        print(imread(filepath_show).shape)
        display_boxes = filter_coordinates_alt(filepath_coordinates, z, imread(filepath_show).shape[0:2])
    
    return (filepath_show, display_boxes), (filepath_result_show, display_boxes)


def filter_coordinates_alt(filepath_coordinates, z, image_shape=(512, 512)):

    depth = 3
    coordinates = np.loadtxt(filepath_coordinates, delimiter=",")

    coordinates = coordinates.astype('int')

    print(coordinates)

    print(np.abs(coordinates[:,0]-z))
    coordinates = coordinates[np.abs(coordinates[:,0]-z)<depth, :]
    
    print(coordinates)
    #xy_coordinates = coordinates[:, (2,1)]
    #rel_z = np.abs(coordinates[:, 0]-z)
    #rel_z = rel_z[:, np.newaxis]

    boxes = np.zeros(image_shape)
    
    x_coord = tuple(coordinates[:,1])
    y_coord = tuple(coordinates[:,2])
    sizes = tuple(4-np.abs(coordinates[:,0]-z))

    print(sizes)
    for x, y, size in zip(x_coord, y_coord, sizes):
        boxes = draw_box(boxes, x, y, size)
    
    return [(boxes,'nothing')]
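
# The overlay returned above uses AnnotatedImage's mask format: a single
# (mask, label) pair whose mask is non-zero inside every box. Detections from
# neighboring z-planes get smaller boxes (size 4 - |z_detection - z|).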

def draw_box(array, x, y, size):
    x0 = max(x-size, 0)
    y0 = max(y-size, 0)

    x1 = min(x+size+1, array.shape[0]-1)
    y1 = min(y+size+1, array.shape[1]-1)

    array[x0:x1, y0:y1] = 1

    return array
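
# Example: draw_box(np.zeros((10, 10)), 5, 5, 2) fills rows 3..7 and columns 3..7
# with ones, i.e. a (2*size+1)-pixel square clipped to the array bounds.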

def add_boxes_norm_path(filepath, boxes):
    img = imread(filepath)
    img = img/(np.max(img)+0.0000001)

    #boxes = np.stack([boxes, np.zeros(boxes.shape),np.zeros(boxes.shape)], axis=-1).astype('int')
    boxes = np.stack([boxes]*3, axis=-1).astype('int')

    print(img.shape)
    print(np.sum(boxes))
    print(boxes.shape)
    img = np.where(boxes>0, boxes, img)

    fpath, fext = os.path.splitext(filepath)
    filepath = fpath + '_with_boxes'+'.png'
    pil_image = Image.fromarray((255. * img).astype(np.uint8))
    pil_image.save(filepath)
    #imsave(filepath, pil_image)
    print('norm_with_box'+filepath)
    
    return filepath
    
def loss(y_true, y_pred):
    # Stub. The trained model was compiled with a weighted mean-square-error loss;
    # at inference time the body is never called, but a callable with this name
    # must exist for keras.saving.load_model to deserialize the file.
    return None
    
def position_precision(y_true, y_pred):
    return loss(y_true, y_pred)

def position_recall(y_true, y_pred):
    return loss(y_true, y_pred)

def overcount(y_true, y_pred):
    return loss(y_true, y_pred)

@spaces.GPU(duration=60)
def run_model_gpu60(model, tile):
    tensor = keras.ops.convert_to_tensor(tile)
    model = keras.saving.load_model(model, custom_objects={'loss': loss,
                                                          'position_precision': position_precision,
                                                          'position_recall': position_recall,
                                                          'overcount': overcount})
    result = model(tensor).cpu().detach().numpy()
    return result
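
# The @spaces.GPU decorator requests a GPU only for the duration of the decorated
# call (the ZeroGPU model on Hugging Face Spaces), which is why the network is
# loaded inside run_model_gpu60 rather than once at startup.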



def detect_cells(filepath, c, model, rescale_z, rescale_xy, progress=gr.Progress()):
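    # Full pipeline: download the chosen model, load the selected channel as a
    # z-stack, rescale toward the model's training resolution, subtract background
    # (the 75th-percentile intensity), normalize, pad and split into overlapping
    # z-tiles, predict a distance map per tile on the GPU, stitch the tiles,
    # rescale back, and extract cell positions as local maxima of the distance map.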
    model = download_model(model_addresses[model])
    #model = tf.keras.models.load_model(model_addresses[model], compile=False)
    #model = keras.saving.load_model("hf://sjtans/OrganoidTracker2_pytorch", compile=False)
    #model = keras.saving.load_model(model, compile=False)
    xy_tile = 32
    img = get_volume(filepath[-1], c=c)
    
    original_shape = img.shape
    img = sc.ndimage.zoom(img, (rescale_z, rescale_xy, rescale_xy))

    background = np.quantile(img, 0.75)

    img = np.maximum(img, background)-background
    
    img = img/np.max(img)

    img_padded= pad(img)
    
    print(img_padded.shape)
    tiles = split_z(img_padded)
    results = []
    print(tiles)
    for tile in tiles:
        tile = np.tile(tile[:,:,:,np.newaxis], [1,1,1,2]) # duplicate the single channel into the two channels the network expects
        tile= tile[np.newaxis,:,:,:,:]
        
        #result = model(tensor).numpy()
        result = run_model_gpu60(model, tile)
        # remove buffer
        result = result[0, :, xy_tile//2 : -xy_tile//2, xy_tile//2 : -xy_tile//2, 0]
        results.append(result)

    result = reconstruct_z(results)
    print(result.shape)

    result = sc.ndimage.zoom(result, (1/rescale_z, 1/rescale_xy, 1/rescale_xy))
    
    result = result[0:original_shape[0],0:original_shape[1], 0:original_shape[2]]
    print(result.shape)
    print(filepath)
    fpath, fext = os.path.splitext(filepath[-1])
    filepath_result = fpath+'result'+'.tiff'
    
    tiff.imwrite(filepath_result, result)
   # filepath_result_show = tif_view(filepath_result, z, show_depth=False)
    #filepath_result_show = norm_path(filepath_result_show)
    result = result/np.max(result)
    #coordinates = peak_local_max(result, min_distance=2, threshold_abs=0.2,  exclude_border=False)
    # the footprint suppresses secondary peaks within ~13 px in xy and ~5 planes in z (in rescaled units)
    coordinates = peak_local_max(result, min_distance=2, footprint=np.ones((int(5//rescale_z), int(13//rescale_xy), int(13//rescale_xy))), threshold_abs=0.2, exclude_border=False)

    print(coordinates)
    filepath_coordinates = fpath+'coordinates'+'.csv'
    np.savetxt(filepath_coordinates, coordinates, delimiter=",")

    #display_boxes = filter_coordinates(filepath_coordinates, z)

    return filepath_result, filepath_coordinates, (fp0, [((5, 5, 10, 10), 'nothing')])#, (filepath_result_show, display_boxes)

def pad(img, z_tile = 32, z_buffer=2, xy_tile = 32):

    # pad z so that tiles of height z_tile placed at stride z_tile - z_buffer
    # exactly cover the stack, i.e. padded depth = z_buffer (mod z_tile - z_buffer)
    if img.shape[0]<z_tile:
        pad_z = z_tile - img.shape[0]
    else:
        pad_z = np.mod(z_buffer - img.shape[0], z_tile - z_buffer)
        
    if np.mod(img.shape[1], xy_tile)>0:
        pad_y = xy_tile-np.mod(img.shape[1], xy_tile) 
    else:
        pad_y = 0

    if np.mod(img.shape[2], xy_tile)>0:
        pad_x = xy_tile-np.mod(img.shape[2], xy_tile)
    else:
        pad_x = 0
    
    return np.pad(img, ((0, pad_z), ( xy_tile//2  , pad_y +  xy_tile//2  ), ( xy_tile//2  , pad_x+ xy_tile//2  )))
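
# Worked example with the defaults (z_tile=32, z_buffer=2, i.e. stride 30): a stack
# of 40 z-planes gets pad_z = mod(2 - 40, 30) = 22, so the padded depth of 62 is
# covered by the tiles [0:32] and [30:62]. In x and y the image is padded up to a
# multiple of 32 plus a 16-pixel border on each side, which detect_cells crops off
# the prediction again.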

def split_z(img, z_tile=32, z_buffer=2):
    
    if img.shape[0]==z_tile:
        return([img])

    tiles = []
    height = 0

    while height<(img.shape[0]-z_buffer):
        tiles.append(img[height:(height+z_tile), :, :])
        height = height+z_tile-z_buffer
        print(height)

    return tiles

def reconstruct_z(tiles, z_tile=32, z_buffer=2):
    
    if len(tiles)==1:
        return tiles[0]

    tiles = [tile[0:(z_tile-z_buffer), :, :] for tile in tiles]

    return np.concatenate(tiles, axis = 0)
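
# split_z and reconstruct_z form a round trip: consecutive tiles overlap by
# z_buffer planes, and reconstruct_z keeps only the first z_tile - z_buffer planes
# of each tile, so the seams are stitched without duplicating any plane.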

def filter_coordinates(filepath_coordinates, z):

    coordinates = np.loadtxt(filepath_coordinates, delimiter=",")
    print(coordinates)
    coordinates = coordinates[np.abs(coordinates[:,0]-z)<3, :]
    print(coordinates)
    xy_coordinates = coordinates[:, (2,1)]
    rel_z = np.abs(coordinates[:, 0]-z)
    rel_z = rel_z[:, np.newaxis]

    print(rel_z)
    
    boxes = np.concatenate((xy_coordinates-4+rel_z, xy_coordinates+4-rel_z), axis=1).astype('uint32')
    print(boxes)
    boxes = [(tuple(box.tolist()),'nothing') for box in boxes]

    print(boxes)
    return boxes
    
with gr.Blocks(title = "OrganoidTracker 2.0", 
                css=""".gradio-container {background:green;} 
                        #examples { background:green;}""") as demo:


    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:25pt; font-weight:bold; text-align:center; color:white;">OrganoidTracker 2.0 for 3D cell tracking 
            <a style="color:#cfe7fe; font-size:14pt;" href="https://www.biorxiv.org/content/10.1101/2024.10.11.617799v1" target="_blank">[paper]</a> 
            <a style="color:#cfe7fe; font-size:14pt;" href="https://organoidtracker.org" target="_blank">[website]</a> 
            <a style="color:#cfe7fe; font-size:14pt;" href="https://jvzonlab.github.io/OrganoidTracker/index.html" target="_blank">[github]</a>
            </div>""")
    gr.HTML("""<h4 style="color:white;"> What is this?:<p> </h4>
                    <ul>
                    <li style="color:white;">Test the performance of our pre-trained networks on your data.
                    <li style="color:white;">We implement only the initial cell detection step, but this is a good performance indicator for the other steps.
                    <li style="color:white;">Does not work? OrganoidTracker 2.0 allows for the easy creation of ground truth datasets and training of new neural networks. 
                    <ul>
                    """)

    #filepath = ""
    with gr.Row():
        with gr.Column(scale=2):
            # <a style="color:white; font-size:14pt;" href="https://www.youtube.com/watch?v=KIdYXgQemcI" target="_blank">[talk]</a>                        

            
            #input_image = gr.Image(label = "Input", type = "filepath")
            input_image = gr.AnnotatedImage(label = "Input", show_legend=False, color_map = {'nothing': '#FFFF00'})

            gr.HTML("""<h4 style="color:white;">You may need to login/refresh for 5 minutes of free GPU compute per day. </h4>""")


            with gr.Row():
                with gr.Column(scale=1):                    
                    with gr.Row():
                        depth = gr.Slider(0, 100, step=1, label = 'z-depth', value = 10, visible=False)
                        channel = gr.Slider(0, 100, label = 'channel', value = 0, visible=False)

                    up_btn = gr.UploadButton("Upload image volume (.tif/.tiff)", visible=True,  file_count = "multiple")  

                    #gr.HTML("""<h4 style="color:white;"> Note2: Only the first image of a tif will display the segmentations, but you can download segmentations for all planes. </h4>""")
                    
                with gr.Column(scale=1):

                    model = gr.Dropdown(
                            model_addresses.keys(), label="Detection model (with resolutions)", info="Will add more models later!"
                                )

                    with gr.Row():
                        rescale_xy = gr.Slider(0.2, 2, step=0.1, label = 'resize xy', value = 1)
                        rescale_z = gr.Slider(0.2, 2, step=0.1,label = 'resize z', value = 1)
                        
                    send_btn = gr.Button("Run cell detection")
                    
        with gr.Column(scale=2):     
            #
            #output_image = gr.Image(label = "Output", type = "filepath")
            output_image = gr.AnnotatedImage(label = "Output", show_legend=False, color_map = {'nothing': '#FFFF00'})

            down_btn = gr.DownloadButton("Download distance map (.tif)", visible=True)            
            down_btn2 = gr.DownloadButton("Download cell detections (.csv)", visible=True)  


    #sample_list = []
    #for j in range(23):
    #    sample_list.append("samples/img%0.2d.png"%j)
    gr.HTML("""<h4 style="color:white;"> Click on an example to try it:
                    </h4>""")
    
    sample_list = os.listdir("./gradio_examples/jpegs")
    #sample_list = [ ("./gradio_examples/jpegs/"+sample, [((5, 5, 10, 10), 'nothing')]) for sample in sample_list]
    
    print(sample_list)
    sample_list = [ "./gradio_examples/jpegs/"+sample for sample in sample_list]
                
    #gr.Examples(sample_list, fn = update_with_example, inputs=input_image, outputs =  [input_image, up_btn, output_image], examples_per_page=50, label = "Click on an example to try it")
    example_image = gr.Image(visible=False, type='filepath')
    gr.Examples(sample_list, fn= update_with_example, inputs=example_image, outputs=[input_image, up_btn, output_image, down_btn, down_btn2, depth, channel], examples_per_page=5, label=' ', 
                cache_examples=False, run_on_click=True, elem_id='examples')
    #gr.Examples(sample_list, fn= example, inputs=example_image, outputs=[example_image], examples_per_page=5, label = "Click on an example to try it")

    #input_image.upload(update_button, [input_image, depth], [input_image, up_btn, output_image])
    up_btn.upload(update_image, [up_btn, depth], [input_image, up_btn, output_image, down_btn, down_btn2,  depth, channel])
    depth.change(update, [up_btn, down_btn, down_btn2, depth, channel], [input_image, output_image])
    channel.change(update, [up_btn, down_btn, down_btn2, depth, channel], [input_image, output_image])
    #depth.change(update_depth, [up_btn, depth], depth)
    

    
    # Prediction
    send_btn.click(detect_cells, [up_btn, channel, model, rescale_z, rescale_xy], [down_btn, down_btn2, output_image]).then(update, [up_btn, down_btn, down_btn2, depth, channel], [input_image, output_image])

    #down_btn.click(download_function, None, [down_btn, down_btn2])
    
    gr.HTML("""<h4 style="color:white;"> Notes:<br>  </h4>
                    <li style="color:white;">You can load and process 3D tifs in the following dimensions: (T),Z,(C),Y,X. We automatically pick the first timepoint.
                    <li style="color:white;">Without GPU access, cell detection might take ~30 seconds. 
                    <li style="color:white;">Locally OrganoidTracker wil run faster: ~2 seconds per frame on a dedicated GPU, ~10 seconds on a CPU. 
                   """)

    gr.HTML("""<h4 style="color:white;"> Caveats:<br>  </h4>
                    <li style="color:white;">For this demo, an agressive background subtraction step is implemented before prediction, which we find benefits most usecases. For transperency, users have to preprocess the data themselves in OrganoidTracker 2.0. 
                    <li style="color:white;">Because of incompatibilities between TensorFlow and HuggingFace the models here are trained with the upcoming PyTorch version of OrganoidTracker (currently in beta). There might be performance differences when using the TensorFlow-versions presented in our paper.
                   """)
    
    gr.HTML("""<h4 style="color:white;"> References:<br>  </h4>
                    <li style="color:white;">The blastocyst sample data is taken from the BlastoSPIM dataset (Nunley et al., Development, 2024): 
                    <a style="color:#cfe7fe" href="https://blastospim.flatironinstitute.org/html/index1.html" target="_blank">[website]</a>,
                    <a style="color:#cfe7fe" href=https://journals.biologists.com/dev/article/151/21/dev202817/362603/Nuclear-instance-segmentation-and-tracking-for target="_blank">[paper]</a>
                    
                    <li style="color:white;">The c Elegans sample data is taken from the Cell Tracking Challenge (Murray et al., Nature Methods, 2008): 
                    <a style="color:#cfe7fe" href="https://celltrackingchallenge.net/3d-datasets/" target="_blank">[website]</a>,
                    <a style="color:#cfe7fe" href=https://www.nature.com/articles/nmeth.1228 target="_blank">[paper]</a>
            """)
    
    
                    
demo.queue().launch()