Lewislou commited on
Commit
a41d6d8
1 Parent(s): 0cdd7e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -26
app.py CHANGED
@@ -20,32 +20,32 @@ def normalize_channel(img, lower=1, upper=99):
20
  return img_norm.astype(np.uint8)
21
 
22
  def predict(filename, model=None, device=None, reduce_labels=True):
23
- if img_name.endswith('.tif') or img_name.endswith('.tiff'):
24
- img_data = tif.imread(img_name)
25
- else:
26
- img_data = io.imread(img_name)
27
- # normalize image data
28
- if len(img_data.shape) == 2:
29
- img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
30
- elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
31
- img_data = img_data[:,:, :3]
32
- else:
33
- pass
34
- pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
35
- for i in range(3):
36
- img_channel_i = img_data[:,:,i]
37
- if len(img_channel_i[np.nonzero(img_channel_i)])>0:
38
- pre_img_data[:,:,i] = normalize_channel(img_channel_i, lower=1, upper=99)
39
- #dummy_input = np.zeros((512,512,3)).astype(np.uint8)
40
- my_model = MultiStreamCellSegModel.from_pretrained("Lewislou/cellseg_sribd")
41
- checkpoints = torch.load('model.pt')
42
- my_model.__init__(ModelConfig())
43
- my_model.load_checkpoints(checkpoints)
44
- with torch.no_grad():
45
- output = my_model(pre_img_data)
46
- overlay = visualize_instances_map(pre_img_data,star_label)
47
- #cv2.imwrite('prediction.png', cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
48
- return overlay
49
  gr.Interface(
50
  predict,
51
  inputs=[gr.components.Image(label="Upload Input Image", type="filepath"),
 
20
  return img_norm.astype(np.uint8)
21
 
22
  def predict(filename, model=None, device=None, reduce_labels=True):
23
+ if img_name.endswith('.tif') or img_name.endswith('.tiff'):
24
+ img_data = tif.imread(img_name)
25
+ else:
26
+ img_data = io.imread(img_name)
27
+ # normalize image data
28
+ if len(img_data.shape) == 2:
29
+ img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
30
+ elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
31
+ img_data = img_data[:,:, :3]
32
+ else:
33
+ pass
34
+ pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
35
+ for i in range(3):
36
+ img_channel_i = img_data[:,:,i]
37
+ if len(img_channel_i[np.nonzero(img_channel_i)])>0:
38
+ pre_img_data[:,:,i] = normalize_channel(img_channel_i, lower=1, upper=99)
39
+ #dummy_input = np.zeros((512,512,3)).astype(np.uint8)
40
+ my_model = MultiStreamCellSegModel.from_pretrained("Lewislou/cellseg_sribd")
41
+ checkpoints = torch.load('model.pt')
42
+ my_model.__init__(ModelConfig())
43
+ my_model.load_checkpoints(checkpoints)
44
+ with torch.no_grad():
45
+ output = my_model(pre_img_data)
46
+ overlay = visualize_instances_map(pre_img_data,star_label)
47
+ #cv2.imwrite('prediction.png', cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
48
+ return overlay
49
  gr.Interface(
50
  predict,
51
  inputs=[gr.components.Image(label="Upload Input Image", type="filepath"),