Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import torch | |
import gradio as gr | |
from lib import create_model | |
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group | |
from lib.dataloader import ImageMixin | |
# Weight file handed to model.load_weight() when the module is loaded.
test_weight = './weight_epoch-200_best.pt'
# Parameter file handed to load_parameter() when the module is loaded.
parameter = './parameters.json'
class ImageHandler(ImageMixin):
    """Convert a single image into the dict the model is called with.

    The transform pipeline is built by ``_make_transforms``, which is not
    defined in this class (presumably inherited from ``ImageMixin`` and
    configured through ``self.params`` -- confirm in ``lib.dataloader``).
    """

    def __init__(self, params):
        """Store *params* and build the transform pipeline once."""
        # params is assigned first, before _make_transforms() runs, exactly
        # as in the original statement order.
        self.params = params
        self.transform = self._make_transforms()

    def set_image(self, image):
        """Return ``{'image': batch}`` for one input image.

        The image is passed through ``self.transform`` and the result gets a
        new leading dimension of size 1 via ``unsqueeze(0)``.
        """
        transformed = self.transform(image)
        return {'image': transformed.unsqueeze(0)}
def load_parameter(parameter):
    """Build the model and dataloader argument groups for CPU inference.

    Args:
        parameter: Passed unchanged to ``_retrieve_parameter``; the module
            calls this function with the path of the parameter file.

    Returns:
        A ``(args_model, args_dataloader)`` tuple: the results of
        ``_dispatch_by_group`` for the ``'model'`` and ``'dataloader'``
        groups, in that order.
    """
    _args = ParamSet()

    # Copy every retrieved parameter onto the ParamSet as an attribute.
    for _param, _arg in _retrieve_parameter(parameter).items():
        setattr(_args, _param, _arg)

    # Fixed overrides applied after the copy, so they replace any retrieved
    # value that has the same name.
    _args.augmentation = 'no'
    _args.sampler = 'no'
    _args.pretrained = False
    _args.mlp = None
    _args.net = _args.model  # 'net' takes the value of the retrieved 'model' entry
    _args.device = torch.device('cpu')

    args_model = _dispatch_by_group(_args, 'model')
    args_dataloader = _dispatch_by_group(_args, 'dataloader')
    return args_model, args_dataloader
# Build the model and load its weights once, at module load, so that every
# call to main() reuses the same instance.
args_model, args_dataloader = load_parameter(parameter)
model = create_model(args_model)
model.load_weight(test_weight)
def main(image):
    """Run the model on one image and return the prediction as text.

    Args:
        image: The value delivered by the Gradio input component wired to
            this function, or ``None`` when no image has been provided.

    Returns:
        The value of the model's first output, formatted with two decimal
        places (for example ``"66.31"``).

    Raises:
        gr.Error: If *image* is ``None``.
    """
    if image is None:
        # Clicking the button with an empty input hands None to this
        # callback; fail with a readable message instead of an opaque error
        # from inside the transform pipeline.
        raise gr.Error("Please provide an image before running inference.")

    model.eval()
    batch = ImageHandler(args_dataloader).set_image(image)

    with torch.no_grad():
        outputs = model(batch)

    # The model returns a mapping; only its first entry is reported.
    label_name = next(iter(outputs))
    # Tensor.item() yields the Python number of a one-element tensor, which
    # is what .detach().numpy().item() produced here.
    value = outputs[label_name].item()
    return f"{value:.2f}"
# Static explanatory panel (preprocessing notes and publication link) that
# the UI renders through gr.HTML.
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
<h3>Image preprocess</h3>
<p>Only grayscale 320×320 resolution works appropriately.</p>
<p>The longest side of the Xp should be downscaled to 320 pixels while maintaining the aspect ratio,
and the width along the shorter side should be padded black to 320 pixels.
</p>
<h3>Publication Details</h3>
<p>See details in our publication, titled
"Chest radiography as a biomarker of ageing: artificial intelligence-based,
multi-institutional model development and validation in Japan"
</p>
<p><strong>Link:</strong> <a href="https://www.thelancet.com/journals/lanhl/article/PIIS2666-7568(23)00133-2/fulltext" target="_blank">
https://www.thelancet.com/journals/lanhl/article/PIIS2666-7568(23)00133-2/fulltext
</a></p>
</div>
"""
# Gradio user interface.
# NOTE(review): the nesting of the `with` blocks below is reconstructed; in
# particular, confirm whether the button belongs below the first row (as
# written here) or inside it.
with gr.Blocks(
    title="Aging Biomarker from CXR",
    css=".gradio-container {background:mintcream;}",
) as demo:
    gr.HTML("""<div style="text-align:center"><h2>Aging Biomarker from CXR</h2></div>""")
    gr.HTML(html_content)

    # Input image and predicted value.
    with gr.Row():
        input_image = gr.Image(type="pil", image_mode="L", shape=(320, 320))
        output_label = gr.Label(label="Estimated age")

    # Clicking the button sends the current image to main() and shows the
    # returned text in the label.
    send_btn = gr.Button("Inference")
    send_btn.click(fn=main, inputs=input_image, outputs=output_label)

    # Bundled sample images that fill the input component when selected.
    with gr.Row():
        gr.Examples(['./samples/66_female_xp.png'], label='Sample CXR 1: 66 years old female', inputs=input_image)
        gr.Examples(['./samples/28_male_xp.png'], label='Sample CXR 2: 28 years old male', inputs=input_image)

demo.launch(debug=True)