deepflash2 / app_onnx.py
matjesg's picture
Create new file
5a6b006
raw
history blame contribute delete
No virus
1.91 kB
import functools

import numpy as np
import gradio as gr
import onnxruntime as ort
from matplotlib import pyplot as plt
from huggingface_hub import hf_hub_download
def create_model_for_provider(model_path, provider="CPUExecutionProvider"):
    """Open an ONNX Runtime inference session restricted to one execution provider.

    Args:
        model_path: Path (str or path-like; converted with ``str``) to the ``.onnx`` file.
        provider: ONNX Runtime execution provider name to run on.

    Returns:
        The configured ``ort.InferenceSession``.
    """
    session_options = ort.SessionOptions()
    # Apply every graph optimization ORT offers, and use one intra-op thread.
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session_options.intra_op_num_threads = 1

    sess = ort.InferenceSession(
        str(model_path),
        session_options,
        providers=[provider],
    )
    # Turn off ORT's run-time fallback to other execution providers.
    sess.disable_fallback()
    return sess
@functools.lru_cache(maxsize=4)
def _load_session(repo_id, model_name):
    """Download ``model_name`` from Hub repo ``repo_id`` and open a CPU ORT session for it.

    Memoized on ``(repo_id, model_name)`` so that repeated requests for the same
    model reuse one session. Without this, every call would rebuild and
    re-optimize the graph. ``maxsize`` bounds memory because both keys come from
    user-editable text boxes. A cached session is not refreshed if the file on
    the Hub changes later.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename=model_name)
    return create_model_for_provider(model_path)


def inference(repo_id, model_name, img):
    """Run the ONNX model ``model_name`` from Hub repo ``repo_id`` on ``img``.

    Args:
        repo_id: Hugging Face Hub repository id, e.g. ``"matjesg/deepflash2_demo"``.
        model_name: Filename of the ``.onnx`` model inside that repository.
        img: Image as a NumPy array, (H, W) or (H, W, C), as delivered by the
            Gradio image input. Assumed to hold 0-255 values, since it is divided
            by 255 before inference.

    Returns:
        ``(mask, uncertainty)``: the model's first output multiplied by 255 and
        its third output divided by 0.25. The UI shows these as the segmentation
        mask and the uncertainty map.

    Raises:
        ValueError: If no image was supplied.
    """
    if img is None:
        raise ValueError("No input image was provided.")

    ort_session = _load_session(repo_id, model_name)
    model_input = ort_session.get_inputs()[0]

    # The last dimension of the model input is the channel count. ORT reports
    # symbolic or unknown dims as str/None; in that case keep every channel.
    n_channels = model_input.shape[-1]
    if not isinstance(n_channels, int):
        n_channels = None

    img = np.asarray(img)
    # Promote grayscale (H, W) to (H, W, 1) so the slice below acts on
    # channels, not on image columns.
    if img.ndim == 2:
        img = img[..., np.newaxis]
    img = img[..., :n_channels] / 255

    ort_outs = ort_session.run(None, {model_input.name: img.astype(np.float32)})
    # NOTE(review): 0.25 presumably rescales the uncertainty output to [0, 1];
    # confirm against the model export.
    return ort_outs[0] * 255, ort_outs[2] / 0.25
# Metadata shown at the top of the Gradio page.
title = "deepflash2"
description = (
    "deepflash2 is a deep-learning pipeline for the segmentation of ambiguous microscopic images.\n"
    " deepflash2 uses deep model ensembles to achieve more accurate and reliable results."
    " Thus, inference time will be more than a minute in this space."
)

# Each example row fills the inputs in order: (repo_id, model_name, image path).
examples = [
    ["matjesg/deepflash2_demo", "cFOS_ensemble.onnx", "cFOS_example.png"],
    ["matjesg/deepflash2_demo", "YFP_ensemble.onnx", "YFP_example.png"],
]

# Inputs map positionally onto inference(repo_id, model_name, img).
input_components = [
    gr.inputs.Textbox(placeholder="e.g., matjesg/cFOS_in_HC", label="repo_id"),
    gr.inputs.Textbox(placeholder="e.g., ensemble.onnx", label="model_name"),
    gr.inputs.Image(type="numpy", label="Input image"),
]
# Outputs receive the two arrays returned by inference(), in order.
output_components = [
    gr.outputs.Image(label="Segmentation Mask"),
    gr.outputs.Image(label="Uncertainty Map"),
]

demo = gr.Interface(
    inference,
    input_components,
    output_components,
    title=title,
    description=description,
    examples=examples,
)
demo.launch()