Lewislou commited on
Commit
02743b2
1 Parent(s): 991881f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -1
app.py CHANGED
@@ -1,3 +1,58 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
2
 
3
- gr.Interface.load("models/Lewislou/cell-seg-sribd").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from skimage import io, segmentation, morphology, measure, exposure
3
+ from sribd_cellseg_models import MultiStreamCellSegModel,ModelConfig
4
+ import numpy as np
5
+ import tifffile as tif
6
+ import requests
7
+ import torch
8
+ from PIL import Image
9
+ from overlay import visualize_instances_map
10
+ import cv2
11
 
12
+
13
+ def normalize_channel(img, lower=1, upper=99):
14
+ non_zero_vals = img[np.nonzero(img)]
15
+ percentiles = np.percentile(non_zero_vals, [lower, upper])
16
+ if percentiles[1] - percentiles[0] > 0.001:
17
+ img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
18
+ else:
19
+ img_norm = img
20
+ return img_norm.astype(np.uint8)
21
+
22
+ def predict(filename, model=None, device=None, reduce_labels=True):
23
+ if img_name.endswith('.tif') or img_name.endswith('.tiff'):
24
+ img_data = tif.imread(img_name)
25
+ else:
26
+ img_data = io.imread(img_name)
27
+ # normalize image data
28
+ if len(img_data.shape) == 2:
29
+ img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
30
+ elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
31
+ img_data = img_data[:,:, :3]
32
+ else:
33
+ pass
34
+ pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
35
+ for i in range(3):
36
+ img_channel_i = img_data[:,:,i]
37
+ if len(img_channel_i[np.nonzero(img_channel_i)])>0:
38
+ pre_img_data[:,:,i] = normalize_channel(img_channel_i, lower=1, upper=99)
39
+ #dummy_input = np.zeros((512,512,3)).astype(np.uint8)
40
+ my_model = MultiStreamCellSegModel.from_pretrained("Lewislou/cellseg_sribd")
41
+ checkpoints = torch.load('model.pt')
42
+ my_model.__init__(ModelConfig())
43
+ my_model.load_checkpoints(checkpoints)
44
+ with torch.no_grad():
45
+ output = my_model(pre_img_data)
46
+ overlay = visualize_instances_map(pre_img_data,star_label)
47
+ #cv2.imwrite('prediction.png', cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
48
+ return overlay
49
+ gr.Interface(
50
+ predict,
51
+ inputs=[gr.components.Image(label="Upload Input Image", type="filepath"),
52
+ gr.components.Textbox(label='Model Name', value='sribd_med', max_lines=1)],
53
+ outputs=[gr.Image(label="Processed Image"),
54
+ gr.Image(label="Label Image"),
55
+ ],
56
+ title="Cell Segmentation Results",
57
+ examples=get_examples(default_model)
58
+ ).launch()