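"""Streamlit demo of the prototype MONAI / Hugging Face Hub integration."""
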
import logging
import os
from glob import glob

import streamlit as st
import torch
from matplotlib import pyplot as plt
from monai.bundle.scripts import upload_zoo_bundle_to_hf


def main():
    st.title("MONAI 🤗 Hugging Face Integration")

    st.write("""\
Here's a demo of a prototype integration between
[MONAI](https://monai.io/) and the [Hugging Face
Hub](https://huggingface.co/docs/hub/index), which allows for
uploading models to the Hub and downloading them. The integration
itself is implemented in [this
branch](https://github.com/dnouri/MONAI/tree/dnouri/huggingface-support)
of MONAI.
""")

    st.write("""\
## Uploading models to the Hub ⬆
The new `upload_zoo_bundle_to_hf` command allows us to upload models
from the existing [MONAI Model
Zoo](https://github.com/Project-MONAI/model-zoo) on GitHub directly
to the Hugging Face Hub.
The `--name` option specifies the [filename of an existing
model](https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1)
in the MONAI Model Zoo, `--hf_organization` specifies the Hugging Face
Hub organization (or user) to upload to, and `--hf_token` is the [HF
user access token](https://huggingface.co/docs/hub/security-tokens).
An additional `--hf_card_data` option allows us to specify [model card
metadata](https://huggingface.co/docs/hub/models-cards#model-card-metadata)
to be added to the Hugging Face model card.
An example call to the `upload_zoo_bundle_to_hf` script looks like
this:
```bash
python -m monai.bundle upload_zoo_bundle_to_hf \\
--name spleen_ct_segmentation_v0.1.0 \\
--hf_organization dnouri --hf_token mytoken \\
--hf_card_data '{"lang": "en"}'
```
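The same upload can also be triggered from Python by calling the
function directly. Here's a minimal sketch (the organization name and
token are placeholders):

```python
from monai.bundle.scripts import upload_zoo_bundle_to_hf

# Upload a Model Zoo release file to the Hub, attaching model card
# metadata; the arguments mirror the CLI options above.
upload_zoo_bundle_to_hf(
    name="spleen_ct_segmentation_v0.1.0",
    hf_organization="dnouri",
    hf_token="mytoken",
    hf_card_data='{"lang": "en"}',
)
```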
An example of a model uploaded this way can be found
[here](https://huggingface.co/dnouri/spleen_ct_segmentation).
### Try it out!
To try out uploading your own model, please provide the information below:
""")

    filename = st.text_input("Filename of MONAI Model Zoo model "
                             "(e.g. ventricular_short_axis_3label_v0.1.0.zip)")
    username = st.text_input("Hub organization or user name (e.g. dnouri)")
    card_data = st.text_input("Optional model card metadata",
                              value='{"tags": ["MONAI"]}')
    token = st.text_input("Hugging Face user access token")

    # Only attempt the upload once all required inputs are filled in:
    if filename and username and token:
        st.write("Please wait...")
        upload_zoo_bundle_to_hf(
            name=filename,
            hf_organization=username,
            hf_token=token,
            hf_card_data=card_data or None,
        )
        st.write(f"""\
Done! You should be able to find the [result here](https://huggingface.co/{username}/{filename.rsplit("_", 1)[0]}).
""")

    st.write("""\
## Downloading models from the Hub ⬇
Uploading isn't much fun if you can't also download the models from
the Hub! To help with that, we've added support for the Hugging Face
Hub to the existing MONAI bundle `download` command.
The `download` command's default `--source` is `github`. We'll choose
`huggingface` instead to download from the Hub.
The `--name` of the model is the name of your model on the Hub,
e.g. `ventricular_short_axis_3label`. Note that as per MONAI
convention, we do not specify the version name here. (Future versions of
this command might allow for downloading specific versions, or tags.)
The `--repo` normally points to the MONAI Model Zoo's ['hosting
storage' release page on
Github](https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1).
When we call `download` with the `huggingface` source, we'll require
the `--repo` argument to point to the organization or user name that
hosts the model, e.g. `dnouri`. (While this choice is a bit
confusing, it also reflects an attempt to pragmatically blend concepts
from both MONAI bundles and the Hub. Future versions might improve on
this.)
An example call to the `download` command that fetches the model we
uploaded previously looks like this:
```bash
python -m monai.bundle download \\
--name spleen_ct_segmentation \\
--source huggingface --repo dnouri
```
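The download can likewise be scripted from Python using the bundle
`download` function; a minimal sketch with the same arguments as the
CLI call above:

```python
from monai.bundle.scripts import download

# Fetch the bundle from the Hub into the local bundle directory:
download(
    name="spleen_ct_segmentation",
    source="huggingface",
    repo="dnouri",
)
```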
""")

    st.write("""\
## Use model for inference 🧠
To use the `spleen_ct_segmentation` pretrained model to do inference,
we'll first load it into memory (as a TorchScript module) using the
`load` function below. This will download the model from the Hugging
Face Hub, as `load` uses the aforementioned `download` under the hood:
""")

    # The next line is a workaround against a buggy interaction
    # between how streamlit sets up stderr and how tqdm uses it:
    logging.getLogger().setLevel(logging.NOTSET)

    with st.echo():
        from monai.bundle.scripts import load

        model, metadata, extra = load(
            name="spleen_ct_segmentation",
            source="huggingface",
            repo="dnouri",
            load_ts_module=True,  # load the TorchScript version of the model
            progress=False,
        )

    st.write("""\
This will produce a model, but we'll also need the corresponding
transforms. These are defined in the MONAI bundle configuration
files. There's unfortunately no convenient way to do this using a
MONAI bundle script function, so we'll have to reach into MONAI's
bowels for a bit:
""")

    with st.echo():
        from monai.bundle.config_parser import ConfigParser
        from monai.bundle.scripts import _process_bundle_dir

        # _process_bundle_dir is a private helper that returns the
        # local bundle directory as a Path:
        model_dir = _process_bundle_dir() / "spleen_ct_segmentation"
        config_paths = [
            model_dir / "configs" / "train.json",
            model_dir / "configs" / "evaluate.json",
        ]
        config = ConfigParser(
            ConfigParser.load_config_files(config_paths),
        )
        preprocess = config.get_parsed_content("validate#preprocessing")

    st.write("""\
We'll borrow code from the [Spleen 3D segmentation with MONAI
tutorial](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/spleen_segmentation_3d.ipynb)
to download the data that our `spleen_ct_segmentation` model was
trained with:
""")

    with st.echo():
        from monai.apps import download_and_extract

        root_dir = os.environ.get(
            "MONAI_DATA_DIRECTORY",
            os.path.expanduser("~/.cache/monai_data_directory"),
        )
        os.makedirs(root_dir, exist_ok=True)

        # Download and unpack the Medical Segmentation Decathlon
        # spleen dataset, unless it's already cached locally:
        resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
        md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
        compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
        data_dir = os.path.join(root_dir, "Task09_Spleen")
        if not os.path.exists(data_dir):
            download_and_extract(resource, compressed_file, root_dir, md5)

        # Pair up image and label files for use with MONAI transforms:
        train_images = sorted(
            glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
        train_labels = sorted(
            glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
        data_dicts = [
            {"image": image_name, "label": label_name}
            for image_name, label_name in zip(train_images, train_labels)
        ]
        files = data_dicts

    st.write(f"Downloaded {len(files)} files.")

    st.write("""\
Finally, we can run inference and plot some results: 🥳
""")

    image_idx = st.slider("Image number", 0, len(files) - 1)
    with st.echo():
        from monai.inferers import sliding_window_inference

        data = preprocess(files[image_idx])
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Run the TorchScript model over the full volume in patches:
        output = sliding_window_inference(
            inputs=data["image"].to(device)[None, ...],
            roi_size=(160, 160, 160),
            sw_batch_size=4,
            predictor=model.to(device).eval(),
        )

        # Plot an axial slice of the image, the ground-truth label,
        # and the model's prediction side by side:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        ax1.set_title("Image")
        ax2.set_title("Label")
        ax3.set_title("Output")
        ax1.imshow(data["image"][0, :, :, 80], cmap="gray")
        ax2.imshow(data["label"][0, :, :, 80], cmap="gray")
        output_img = (
            torch.argmax(output, dim=1)[0, :, :, 80]
            .cpu().detach().numpy()
        )
        ax3.imshow(output_img, cmap="gray")

    st.pyplot(fig)


if __name__ == "__main__":
    main()