File size: 11,638 Bytes
49c6db7
 
8786ac3
49c6db7
 
 
f0821bf
 
44c9541
e16c5c4
3e66137
47accba
 
3e66137
47accba
266f336
47accba
 
465186b
 
3e66137
 
d874e72
3e66137
 
 
 
 
2fefa26
 
 
 
 
 
 
d51d5c4
2fefa26
 
 
 
 
 
3e66137
49c6db7
ef2c520
 
49c6db7
 
 
 
3e66137
 
 
49c6db7
 
 
 
 
 
 
 
 
 
f0821bf
fd26ead
49c6db7
 
 
 
 
 
 
 
f0821bf
 
 
 
 
 
 
 
 
 
1c5339e
f0821bf
 
 
 
c034f55
 
 
 
 
 
 
 
 
 
 
 
 
736285e
c034f55
ea7f537
 
 
49c6db7
 
bf308b6
 
 
 
 
 
49c6db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32b0ba4
49c6db7
 
 
 
 
 
 
 
 
 
 
 
 
eeb07d9
ded4361
a887322
c994db7
a887322
 
c08aa90
 
c994db7
c08aa90
 
 
 
c994db7
c08aa90
 
 
 
c994db7
c08aa90
 
f9b1b13
0b8b42c
5ba71fa
4bc8e38
 
 
8d229b1
4bc8e38
 
 
 
 
 
 
 
 
 
 
 
 
c08aa90
4bc8e38
 
 
 
5ba71fa
4bc8e38
 
5ba71fa
4bc8e38
 
5ba71fa
a887322
49c6db7
 
 
18b4441
8a8ccfd
5ba71fa
958ea27
5ba71fa
18b4441
 
 
4c3c584
a08adb1
eecd9f2
7902217
 
 
 
5ba71fa
7902217
 
5ba71fa
e16c5c4
7170f20
4bc8e38
 
 
 
e16c5c4
4c3c584
80e25f6
49c6db7
 
2c27168
 
 
 
 
 
 
49c6db7
fd1e2f9
25641bf
 
 
 
2d8800e
fcdd787
5261ab2
 
93d755b
5261ab2
 
cb2e681
5261ab2
 
e273c25
 
5ba71fa
93d755b
5bf1496
 
4bde3b4
 
 
5bf1496
 
 
fcdd787
2d8800e
2c27168
 
 
c77528c
2c27168
c3abe48
95242a5
 
 
 
03ae964
 
2e35a3d
1fe72e5
 
22bab81
741a210
22bab81
f43ddd6
daeff40
e19ebf3
5261ab2
03ae964
16b034a
f43ddd6
c0e39ef
 
7798670
 
18b4441
03ae964
f43ddd6
 
 
7902217
4d163cf
 
 
 
 
a6f1288
b2eef14
fa3e72c
fcdd787
4bde3b4
a3dc09f
25641bf
2c27168
157bb22
 
 
 
2c27168
fe58ba1
 
 
2c27168
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
import numpy as np
import gradio as gr
import spaces
import cv2
from cellpose import models
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt
import os, io, base64
from PIL import Image 
from cellpose.io import imread, imsave

from huggingface_hub import hf_hub_download

# @title Data retrieval
def download_weights():
    """Fetch the ``cpsam`` checkpoint from the ``mouseland/cellpose-sam`` Hub repo.

    Returns
    -------
    str
        Local path of the downloaded (or already cached) checkpoint file, as
        returned by ``huggingface_hub.hf_hub_download``.
    """
    checkpoint_path = hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam")
    return checkpoint_path

def download_weights_old():
    """Legacy downloader: fetch the ``cpsam`` checkpoint from OSF into the CWD.

    Files that already exist locally are skipped. Each download is retried up
    to 10 times on network errors. Failures are reported with a printed
    message rather than raised (best-effort, as before).
    """
    import os, requests

    fname = ['cpsam']
    url = ["https://osf.io/d7c8e/download"]
    max_tries = 10

    for path, link in zip(fname, url):
        if os.path.isfile(path):
            continue  # already downloaded; nothing to check or write

        r = None
        for ntries in range(1, max_tries + 1):
            try:
                # A timeout keeps a stalled connection from hanging forever.
                r = requests.get(link, timeout=60)
                break  # success: stop retrying
            except requests.exceptions.RequestException:
                print("!!! Failed to download data !!!")
                print(ntries)

        if r is None or r.status_code != requests.codes.ok:
            print("!!! Failed to download data !!!")
        else:
            with open(path, "wb") as fid:
                fid.write(r.content)

try:
    #fpath = download_weights()
    model = models.CellposeModel(gpu=True)# , pretrained_model=fpath)
except Exception as e:
    # The app cannot run without a model, so abort startup. SystemExit with a
    # message prints it to stderr and exits with status 1. Unlike the
    # `site`-injected `exit()` helper, it is always available.
    raise SystemExit(f"Error loading model: {e}") from e



            
def plot_flows(y):
    """Render two flow components as an RGB image.

    Hue encodes the angle ``arctan2(Y, X)`` of the percentile-normalized
    components. Saturation and value both encode the normalized squared
    magnitude.

    ``y`` is indexed as ``y[0][0]`` and ``y[1][0]``. These are presumably the
    y- and x-flow fields; confirm against the caller (none is visible here).

    Returns
    -------
    np.ndarray
        uint8 RGB image.
    """
    comp_y, comp_x = y[0][0], y[1][0]

    # Map each component to [-1, 1] after robust normalization.
    unit_y = (np.clip(normalize99(comp_y), 0, 1) - 0.5) * 2
    unit_x = (np.clip(normalize99(comp_x), 0, 1) - 0.5) * 2

    hue = (np.arctan2(unit_y, unit_x) + np.pi) / (2 * np.pi)
    strength = normalize99(comp_y ** 2 + comp_x ** 2)

    channels = [c[:, :, np.newaxis] for c in (hue, strength, strength)]
    hsv = np.clip(np.concatenate(channels, axis=-1), 0.0, 1.0)
    return (hsv_to_rgb(hsv) * 255).astype(np.uint8)

def plot_outlines(img, masks):
    """Draw the outline of every mask in red over the normalized image.

    Parameters
    ----------
    img : np.ndarray
        Image the masks belong to. It is passed through ``normalize99`` before
        display, and its first two axes set the plot extent.
    masks : np.ndarray
        Label image with the same height/width as ``img``. Contours are traced
        with ``cv2.findContours``.

    Returns
    -------
    PIL.Image.Image
        The rendered figure (black face color) decoded from an in-memory PNG.
    """
    img = normalize99(img)

    # Cast labels to int32: OpenCV accepts 32-bit integer label images in
    # RETR_FLOODFILL mode.
    contours, _ = cv2.findContours(masks.astype(np.int32), mode=cv2.RETR_FLOODFILL,
                                   method=cv2.CHAIN_APPROX_SIMPLE)
    outpix = [
        cv2.approxPolyDP(contour, 0.001, True)[:, 0, :]
        for contour in contours
        if len(contour) > 4  # skip tiny, degenerate contours
    ]

    ny, nx = img.shape[:2]
    # The long side is 6 inches; the short side keeps the image aspect ratio.
    if ny > nx:
        figsize = (6 * nx / ny, 6)
    else:
        figsize = (6, 6 * ny / nx)

    fig = plt.figure(figsize=figsize, facecolor='k')
    try:
        ax = fig.add_axes([0.0, 0.0, 1, 1])
        ax.set_xlim([0, nx])
        ax.set_ylim([0, ny])
        # The image is flipped vertically and contour y-coords become ny - y,
        # so outlines land on the pixels they came from.
        ax.imshow(img[::-1], origin='upper', aspect='auto')
        for o in outpix:
            ax.plot(o[:, 0], ny - o[:, 1], color=[1, 0, 0], lw=1)
        ax.axis('off')

        buf = io.BytesIO()
        fig.savefig(buf, bbox_inches='tight')
        buf.seek(0)
        pil_img = Image.open(buf)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        # Without this, each call leaks one figure.
        plt.close(fig)

    return pil_img

def plot_overlay(img, masks):
    """Tint each mask with a random hue over a grayscale copy of ``img``.

    Parameters
    ----------
    img : np.ndarray
        2-D image, or an image whose last axis is averaged to grayscale.
    masks : np.ndarray
        Integer label image with the same height/width; label 0 is left
        untinted.

    Returns
    -------
    np.ndarray
        uint8 RGB array of shape ``(H, W, 3)``.
    """
    if img.ndim > 2:
        img_gray = img.astype(np.float32).mean(axis=-1)
    else:
        img_gray = img.astype(np.float32)

    img = normalize99(img_gray)
    img -= img.min()
    span = img.max()
    if span > 0:  # a constant image would otherwise become all-NaN
        img /= span

    HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
    HSV[:, :, 2] = np.clip(img * 1.5, 0, 1.0)

    # Draw one random hue per label 1..n in a single call. This yields the same
    # values as n successive np.random.rand() calls. Pixels are colored in one
    # vectorized pass instead of one full-image scan per label.
    n_labels = max(int(masks.max()), 0)
    hues = np.random.rand(n_labels)
    labels = np.asarray(masks).astype(np.int64)
    fg = (labels >= 1) & (labels <= n_labels)
    HSV[fg, 0] = hues[labels[fg] - 1]
    HSV[fg, 1] = 1.0

    RGB = (hsv_to_rgb(HSV) * 255).astype(np.uint8)
    return RGB

def normalize99(img, lower=1, upper=99):
    """Rescale ``img`` so its ``lower``/``upper`` percentiles map to 0 and 1.

    Values outside that percentile range land outside [0, 1]; callers clip if
    needed. The input is not modified.

    Parameters
    ----------
    img : array_like
        Numeric image of any shape.
    lower, upper : float, optional
        Percentiles mapped to 0 and 1 (defaults 1 and 99).

    Returns
    -------
    np.ndarray
        Float array with the same shape as ``img``. If the two percentiles are
        equal (e.g. a constant image), an all-zero array is returned instead of
        dividing by zero.
    """
    img = np.asarray(img)
    lo, hi = np.percentile(img, [lower, upper])
    if hi == lo:
        return np.zeros_like(img - lo)
    return (img - lo) / (hi - lo)

def _to_uint8(img):
    """Return ``img`` as uint8; non-uint8 data is linearly mapped from [min, max] to [0, 255]."""
    img = np.asarray(img)
    if img.dtype == np.uint8:
        return img
    data = img.astype(np.float32)
    if data.size == 0:
        return data.astype(np.uint8)
    lo, hi = float(data.min()), float(data.max())
    if hi > lo:
        data = (data - lo) * (255.0 / (hi - lo))
    else:
        data = np.zeros_like(data)  # constant image: nothing to scale
    return np.clip(np.round(data), 0, 255).astype(np.uint8)


def image_resize(img, resize=400):
    """Convert ``img`` to uint8 and downscale it so its longer side is at most ``resize``.

    Parameters
    ----------
    img : np.ndarray
        Image whose first two axes are rows and columns.
    resize : int or float
        Maximum length of the longer side. Floats (e.g. from a ``gr.Number``)
        are truncated to int, because ``cv2.resize`` needs integer sizes.

    Returns
    -------
    np.ndarray
        uint8 image. uint8 input keeps its values. Any other dtype is linearly
        rescaled to [0, 255]; a bare ``astype`` would wrap 16-bit data and
        zero out [0, 1] floats.
    """
    resize = int(resize)
    img = _to_uint8(img)
    ny, nx = img.shape[:2]
    if max(ny, nx) > resize:
        if ny > nx:
            nx = max(1, int(nx / ny * resize))  # never shrink an axis to 0 px
            ny = resize
        else:
            ny = max(1, int(ny / nx * resize))
            nx = resize
        # cv2.resize takes the target size as (width, height).
        img = cv2.resize(img, (nx, ny))
    return img

    
def _eval_masks_flows(img):
    """Evaluate the module-level Cellpose ``model`` on ``img`` and return ``(masks, flows)``.

    ``model.eval`` must return exactly three values; the third is discarded.
    """
    masks, flows, _unused = model.eval(img)
    return masks, flows


# The runners below are identical except for the ZeroGPU time budget they
# request (``duration``), so the caller can ask for more time on bigger images.
@spaces.GPU(duration=10)
def run_model_gpu(img):
    """Segment ``img`` under a 10-second GPU allocation."""
    return _eval_masks_flows(img)


@spaces.GPU(duration=60)
def run_model_gpu60(img):
    """Segment ``img`` under a 60-second GPU allocation."""
    return _eval_masks_flows(img)


@spaces.GPU(duration=240)
def run_model_gpu240(img):
    """Segment ``img`` under a 240-second GPU allocation."""
    return _eval_masks_flows(img)


@spaces.GPU(duration=1000)
def run_model_gpu1000(img):
    """Segment ``img`` under a 1000-second GPU allocation."""
    return _eval_masks_flows(img)

from zipfile import ZipFile
def _run_model_by_size(img):
    """Run Cellpose on ``img`` with the GPU wrapper whose time budget fits its size.

    The largest entry of ``img.shape`` selects the wrapper:
    < 1000 -> 10 s, < 5000 -> 60 s, < 20000 -> 240 s.

    Returns
    -------
    tuple
        ``(masks, flows)`` from the selected wrapper.

    Raises
    ------
    ValueError
        If the largest dimension is 20,000 or more.
    """
    max_dim = np.max(img.shape)
    if max_dim < 1000:
        return run_model_gpu(img)
    if max_dim < 5000:
        return run_model_gpu60(img)
    if max_dim < 20000:
        return run_model_gpu240(img)
    raise ValueError("Image size must be less than 20,000")


def cellpose_segment(filepath, resize = 1000):
    """Segment every uploaded image and build the previews and download buttons.

    Parameters
    ----------
    filepath : list of str
        Paths from the multi-file upload button.
    resize : int or float or None
        Maximum side length each image is downscaled to before segmentation.
        ``None`` (an empty number box) falls back to 1000.

    Returns
    -------
    tuple
        ``(outlines, flows, masks_button, outlines_button)``.
        - ``outlines`` and ``flows`` are PIL previews of the LAST image.
        - ``masks_button`` points at its ``_masks.tif``, or at a zip of all
          mask files when several images were given.
        - ``outlines_button`` points at the outline PNG.

    Raises
    ------
    gr.Error
        If nothing was uploaded or ``resize`` is not positive.
    ValueError
        If a resized image is 20,000 px or more along its largest axis.
    """
    import tempfile

    if not filepath:
        # e.g. an example was clicked, but no file went through the upload button.
        raise gr.Error("Please upload at least one image before running Cellpose-SAM.")
    resize = 1000 if resize is None else int(resize)
    if resize <= 0:
        raise gr.Error("'max resize' must be a positive number.")

    # Create a fresh directory for each run, next to the uploads (where the mask
    # files are also written and served from). A fixed 'masks.zip' in the working
    # directory would be overwritten by concurrent sessions.
    upload_dir = os.path.dirname(os.path.abspath(filepath[0]))
    zip_path = os.path.join(tempfile.mkdtemp(dir=upload_dir), "masks.zip")

    with ZipFile(zip_path, 'w') as myzip:
        for path in filepath:
            img_input = imread(path)
            # `resize` is never reassigned in this loop, so every file is capped
            # at the user's limit.
            img = image_resize(img_input, resize=resize)
            masks_small, flows = _run_model_by_size(img)

            masks = masks_small
            target_size = (img_input.shape[1], img_input.shape[0])
            if target_size != (img.shape[1], img.shape[0]):
                # Scale labels back to the original size. Nearest-neighbour keeps
                # label ids intact.
                masks = cv2.resize(masks_small.astype('uint16'), target_size,
                                   interpolation=cv2.INTER_NEAREST).astype('uint16')

            fname_masks = os.path.splitext(path)[0] + "_masks.tif"
            imsave(fname_masks, masks)
            # Store only the file name, not the full temporary path, in the archive.
            myzip.write(fname_masks, arcname=os.path.basename(fname_masks))

    # Previews show the last image at model resolution, so the outlines line up
    # with `img`. Full-size masks would overshoot the plot axes.
    outpix = plot_outlines(img, masks_small)
    flows_img = Image.fromarray(flows[0])

    Ly, Lx = img.shape[:2]
    outpix = outpix.resize((Lx, Ly), resample=Image.BICUBIC)
    flows_img = flows_img.resize((Lx, Ly), resample=Image.BICUBIC)

    fname_out = os.path.splitext(filepath[-1])[0] + "_outlines.png"
    outpix.save(fname_out)

    masks_download = zip_path if len(filepath) > 1 else fname_masks
    b1 = gr.DownloadButton(visible=True, value=masks_download)
    b2 = gr.DownloadButton(visible=True, value=fname_out)
    return outpix, flows_img, b1, b2

# Gradio Interface
#iface = gr.Interface(
#    fn=cellpose_segment, 
#    inputs="image", 
#    outputs=["image", "image", "image", "image"],
#    title="cellpose segmentation",
#    description="upload an image, then cellpose will segment it at a max size of 400x400 (for full functionality, 'pip install cellpose' locally)"
#)

def download_function(): 
    """Return a ``(masks, outlines)`` pair of hidden download buttons."""
    labels = ("Download masks as TIFF", "Download outline image as PNG")
    return tuple(gr.DownloadButton(label, visible=False) for label in labels)

def upload_file(filepath): 
    """Log every uploaded path and return the last one for the input preview.

    Parameters
    ----------
    filepath : list of str
        Files from the multi-file upload button.

    Returns
    -------
    str
        ``filepath[-1]``, shown in the input image component.
    """
    for uploaded_path in filepath:
        print(uploaded_path)
    newest = filepath[-1]
    return newest

 
def update_image(filepath): 
    """Wrap a single image path in a list, the value shape of the multi-file upload button."""
    as_list = list()
    as_list.append(filepath)
    return as_list
    
with gr.Blocks(title = "Hello", 
               css=".gradio-container {background:purple;}") as demo:

    #filepath = ""
    with gr.Row():
        with gr.Column(scale=2):
            # NOTE(review): the [paper] link points at the GitHub repo, same as [github].
            gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:20pt; font-weight:bold; text-align:center; color:white;">Cellpose-SAM for cellular 
            segmentation <a style="color:#cfe7fe; font-size:14pt;" href="https://github.com/MouseLand/cellpose" target="_blank">[paper]</a> 
            <a style="color:white; font-size:14pt;" href="https://github.com/MouseLand/cellpose" target="_blank">[github]</a>
            </div>""")
            gr.HTML("""<h4 style="color:white;">You may need to login/refresh for 5 minutes of free GPU compute per day (enough to process hundreds of images). </h4>""")
            
            input_image = gr.Image(label = "Input", type = "filepath")

            with gr.Row():
                with gr.Column(scale=1):                    
                    up_btn = gr.UploadButton("Multi-file upload (png, jpg, tif etc)", visible=True, file_count = "multiple")
                    resize = gr.Number(label = 'max resize', value = 1000)                
                    
                    #gr.HTML("""<h4 style="color:white;"> Note2: Only the first image of a tif will display the segmentations, but you can download segmentations for all planes. </h4>""")

                with gr.Column(scale=1):
                    send_btn = gr.Button("Run Cellpose-SAM")
                    down_btn = gr.DownloadButton("Download masks (TIF)", visible=False)            
                    down_btn2 = gr.DownloadButton("Download outlines (PNG)", visible=False)  

        with gr.Column(scale=2):     
            img_outlines = gr.Image(label = "Outlines", type = "pil", format = 'png') #, width = "50vw", height = "20vw")
            #img_overlay = gr.Image(label = "Overlay", type = "pil", format = 'png') #, width = "50vw", height = "20vw")
            flows = gr.Image(label = "Cellpose flows", type = "pil", format = 'png') #, width = "50vw", height = "20vw")

    sample_list = [f"samples/img{j:02d}.png" for j in range(23)]
    # "Run" reads its files from up_btn, not input_image. Clicking an example
    # therefore also runs update_image to copy the path into up_btn; otherwise
    # segmentation would receive no files. Caching is off because this is a UI
    # sync step, not a computation worth storing.
    gr.Examples(sample_list, inputs=input_image, outputs=up_btn, fn=update_image,
                run_on_click=True, cache_examples=False,
                examples_per_page=25, label = "Click on an example to try it")
    
    input_image.upload(update_image, input_image, up_btn)
    up_btn.upload(upload_file, up_btn, input_image)
    send_btn.click(cellpose_segment, [up_btn, resize], [img_outlines, flows, down_btn, down_btn2])

    #down_btn.click(download_function, None, [down_btn, down_btn2])
        
    gr.HTML("""<h4 style="color:white;"> Notes:<br> 
                    <li>you can load and process single-image tifs, but they won't display in the input field above. 
                    <li>install Cellpose-SAM locally for full functionality.
                    </h4>""")
    
    # <li>the smallest dimension of a tif --> channels 
    # <li>you can load multiple files and download a zip of the segmentations  

if __name__ == "__main__":
    demo.launch()