matjesg commited on
Commit
70aa6d4
1 Parent(s): b345345

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import numpy as np
2
  import gradio as gr
 
3
  import onnxruntime as ort
4
  from matplotlib import pyplot as plt
5
  from huggingface_hub import hf_hub_download
6
 
7
- model = hf_hub_download(repo_id="matjesg/cFOS_in_HC", filename="ensemble.onnx")
8
-
9
  def create_model_for_provider(model_path, provider="CPUExecutionProvider"):
10
  options = ort.SessionOptions()
11
  options.intra_op_num_threads = 1
@@ -14,26 +13,29 @@ def create_model_for_provider(model_path, provider="CPUExecutionProvider"):
14
  session.disable_fallback()
15
  return session
16
 
17
- ort_session = create_model_for_provider(model)
18
-
19
- def inference(img):
20
-
21
- img = img[...,:1]/255
22
 
 
23
  ort_inputs = {ort_session.get_inputs()[0].name: img.astype(np.float32)}
24
 
25
  ort_outs = ort_session.run(None, ort_inputs)
26
 
27
- return ort_outs[0]*255
28
-
29
-
30
  title="deepflash2"
31
- description="deepflash2 is a deep-learning pipeline for segmentation of ambiguous microscopic images."
32
- examples=[['cFOS_example.png']]
33
 
34
  gr.Interface(inference,
35
- gr.inputs.Image(type="numpy"),
36
- gr.outputs.Image(),
 
 
 
 
37
  title=title,
38
  description=description,
39
  examples=examples
 
1
  import numpy as np
2
  import gradio as gr
3
+ import imageio.v2 as imageio
4
  import onnxruntime as ort
5
  from matplotlib import pyplot as plt
6
  from huggingface_hub import hf_hub_download
7
 
 
 
8
  def create_model_for_provider(model_path, provider="CPUExecutionProvider"):
9
  options = ort.SessionOptions()
10
  options.intra_op_num_threads = 1
 
13
  session.disable_fallback()
14
  return session
15
 
16
+ def inference(repo_id, model_name, img):
17
+ model = hf_hub_download(repo_id=repo_id, filename=model_name)
18
+ ort_session = create_model_for_provider(model)
19
+ n_channels = ort_session.get_inputs()[0].shape[-1]
 
20
 
21
+ img = img[...,:n_channels]/255
22
  ort_inputs = {ort_session.get_inputs()[0].name: img.astype(np.float32)}
23
 
24
  ort_outs = ort_session.run(None, ort_inputs)
25
 
26
+ return ort_outs[0]*255, ort_outs[2]/0.25
27
+
 
28
  title="deepflash2"
29
+ description='deepflash2 is a deep-learning pipeline for the segmentation of ambiguous microscopic images.\n deepflash2 uses deep model ensembles to achieve more accurate and reliable results. Thus, inference time will be more than a minute in this space.'
30
+ examples=[['matjesg/cFOS_in_HC', 'ensemble.onnx', '0001_cFOS.png']]
31
 
32
  gr.Interface(inference,
33
+ [gr.inputs.Textbox(placeholder='e.g., matjesg/cFOS_in_HC', label='repo_id'),
34
+ gr.inputs.Textbox(placeholder='e.g., ensemble.onnx', label='model_name'),
35
+ gr.inputs.Image(type='numpy', label='Input image')
36
+ ],
37
+ [gr.outputs.Image(label='Segmentation Mask'),
38
+ gr.outputs.Image(label='Uncertainty Map')],
39
  title=title,
40
  description=description,
41
  examples=examples