# File-viewer header captured with this source (not code):
# author: fadindashfr · commit 4c5329d "initial commit all file" · 3.24 kB
import torch
from monai.bundle import ConfigParser
import gradio as gr
# Build the MONAI inference components from the bundle's JSON configs.
parser = ConfigParser()  # holds the parsed bundle configuration
parser.read_config(f="configs/inference.json")  # inference workflow definition
parser.read_meta(f="configs/metadata.json")  # bundle metadata

inference = parser.get_parsed_content("inferer")  # called as inference(inputs, network)
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")  # called on {"image", "label"} dicts

# Nothing in this script moves the network to a GPU, and predictions are read
# back with .cpu(), so load the weights onto the CPU. This also lets a
# checkpoint that was saved from a CUDA device load on a CPU-only host.
state_dict = torch.load("models/model.pt", map_location="cpu")
network.load_state_dict(state_dict, strict=True)  # strict: every key must match
# Maps each output index of the classifier to a display name; these names
# become the keys of the probability dict shown in the UI.
class_names = dict(
    enumerate(
        (
            "Other",
            "Inflammatory",
            "Epithelial",
            "Spindle-Shaped",
        )
    )
)
def classify_image(image_file, label_file):
    """Classify a nucleus from an image patch and its label mask.

    Args:
        image_file: File path of the RGB image (the Gradio input is created
            with ``type="filepath"``).
        label_file: File path of the single-channel ("L") label mask.

    Returns:
        dict mapping each name in ``class_names`` to its softmax probability,
        in the form ``gr.Label`` displays.

    Raises:
        gr.Error: if either input is missing, e.g. "Process" was clicked
            before both images were provided.
    """
    # Fail early with a readable message instead of an opaque error raised
    # deep inside the preprocessing transforms.
    if image_file is None or label_file is None:
        raise gr.Error("Please provide both an image and its label mask.")

    data = {"image": image_file, "label": label_file}
    batch = preprocess(data)

    network.eval()
    with torch.no_grad():
        # unsqueeze adds a leading batch dimension of size 1.
        # NOTE(review): per the original author, the model expects 4 input
        # channels (3 RGB + 1 label mask) assembled by `preprocess` — confirm
        # against configs/inference.json.
        pred = inference(batch["image"].unsqueeze(dim=0), network)
        # Softmax over the last (class) axis; [0] drops the batch dimension.
        # No .detach() needed: no autograd graph is recorded under no_grad.
        prob = pred.softmax(-1).cpu().numpy()[0]

    # Iterate the mapping directly so keys need not be a dense 0..n-1 range.
    return {name: float(prob[idx]) for idx, name in class_names.items()}
# Example [image_path, label_path] pairs offered as sample inputs in the UI.
# Paths use "/" so they resolve on both POSIX and Windows; on Linux a
# backslash is an ordinary filename character, so Windows-style separators
# would point at files that do not exist.
_EXAMPLE_DIR = "sample_data"


def _example_pair(stem):
    """Return the [image_path, label_path] pair for the sample named *stem*."""
    return [
        f"{_EXAMPLE_DIR}/Images/{stem}.png",
        f"{_EXAMPLE_DIR}/Labels/{stem}.png",
    ]


example_files1 = [
    _example_pair(stem)
    for stem in (
        "test_11_2_0628",
        "test_9_4_0149",
        "test_12_3_0292",
        "test_9_4_0019",
    )
]
example_files2 = [
    _example_pair(stem)
    for stem in (
        "test_14_3_0433",
        "test_14_4_0544",
        "train_1_1_0095",
        "train_1_3_0020",
    )
]
# Markdown shown at the top of the app. Read as UTF-8 explicitly: without an
# encoding, open() uses the platform's locale encoding (e.g. cp1252 on many
# Windows setups), which can fail on or garble non-ASCII characters.
with open("Description.md", "r", encoding="utf-8") as file:
    markdown_content = file.read()
with gr.Blocks() as app:
    gr.Markdown("# Pathology Nuclei Classification")
    gr.Markdown(markdown_content)

    with gr.Row():
        # Left column: the two inputs (passed to the callback as file paths)
        # and the action buttons.
        # NOTE(review): the Label is placed beside this column in the outer
        # Row — confirm this matches the intended layout.
        with gr.Column():
            with gr.Row():
                inp_img = gr.Image(type="filepath", image_mode="RGB")
                label_img = gr.Image(type="filepath", image_mode="L")
            with gr.Row():
                process_btn = gr.Button(value="Process")
                clear_btn = gr.Button(value="Clear")
        out_txt = gr.Label(label="Probabilities", num_top_classes=4)

    process_btn.click(fn=classify_image, inputs=[inp_img, label_img], outputs=out_txt)
    # One fresh "empty" update per cleared component: image, mask, label.
    clear_btn.click(
        lambda: tuple(gr.update(value=None) for _ in range(3)),
        inputs=None,
        outputs=[inp_img, label_img, out_txt],
    )

    gr.Markdown("## Image Examples")
    # Each example list gets its own row, with one Examples widget per pair.
    for example_row in (example_files1, example_files2):
        with gr.Row():
            for pair in example_row:
                gr.Examples([pair], inputs=[inp_img, label_img])

app.launch()