Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -20,32 +20,32 @@ def normalize_channel(img, lower=1, upper=99):
|
|
20 |
return img_norm.astype(np.uint8)
|
21 |
|
22 |
def predict(filename, model=None, device=None, reduce_labels=True):
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
gr.Interface(
|
50 |
predict,
|
51 |
inputs=[gr.components.Image(label="Upload Input Image", type="filepath"),
|
|
|
20 |
return img_norm.astype(np.uint8)
|
21 |
|
22 |
def predict(filename, model=None, device=None, reduce_labels=True):
|
23 |
+
if img_name.endswith('.tif') or img_name.endswith('.tiff'):
|
24 |
+
img_data = tif.imread(img_name)
|
25 |
+
else:
|
26 |
+
img_data = io.imread(img_name)
|
27 |
+
# normalize image data
|
28 |
+
if len(img_data.shape) == 2:
|
29 |
+
img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
|
30 |
+
elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
|
31 |
+
img_data = img_data[:,:, :3]
|
32 |
+
else:
|
33 |
+
pass
|
34 |
+
pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
|
35 |
+
for i in range(3):
|
36 |
+
img_channel_i = img_data[:,:,i]
|
37 |
+
if len(img_channel_i[np.nonzero(img_channel_i)])>0:
|
38 |
+
pre_img_data[:,:,i] = normalize_channel(img_channel_i, lower=1, upper=99)
|
39 |
+
#dummy_input = np.zeros((512,512,3)).astype(np.uint8)
|
40 |
+
my_model = MultiStreamCellSegModel.from_pretrained("Lewislou/cellseg_sribd")
|
41 |
+
checkpoints = torch.load('model.pt')
|
42 |
+
my_model.__init__(ModelConfig())
|
43 |
+
my_model.load_checkpoints(checkpoints)
|
44 |
+
with torch.no_grad():
|
45 |
+
output = my_model(pre_img_data)
|
46 |
+
overlay = visualize_instances_map(pre_img_data,star_label)
|
47 |
+
#cv2.imwrite('prediction.png', cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
|
48 |
+
return overlay
|
49 |
gr.Interface(
|
50 |
predict,
|
51 |
inputs=[gr.components.Image(label="Upload Input Image", type="filepath"),
|