|
import streamlit as st |
|
import json |
|
import numpy as np |
|
import nibabel as nib |
|
import torch |
|
import scipy.io |
|
from io import BytesIO |
|
from transformers import AutoModel |
|
import os |
|
import tempfile |
|
from pathlib import Path |
|
import pandas as pd |
|
|
|
|
|
# Browser-tab title/icon and default layout, applied before any other element
# of the app is rendered.
_PAGE_CONFIG = dict(
    page_title="DiffAE3D | Cardiac MRI Phenotyping",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
|
|
|
|
|
# Sidebar content: app title, usage instructions and credits.
_SIDEBAR_INSTRUCTIONS = """
This application allows you to upload a 3D NIfTI file (dims: time x H x W), process it through a pre-trained 3D DiffAE model, and download the output as a `.json` or `.csv` file containing 128 latent factors.

**Instructions**:
- Upload your 3D NIfTI file (`.nii` or `.nii.gz`). It should be a single-slice cardiac long-axis dynamic CINE scan, where the first dimension represents time.
- Select a seed value from the dropdown menu.
- Click the "Process" button to generate the latent factors.
"""

sidebar = st.sidebar
sidebar.title("Obtaining unsupervised phenotypes from cardiac MRIs: UK Biobank, 20208 | DiffAE3D")
sidebar.markdown(_SIDEBAR_INSTRUCTIONS)
sidebar.markdown("---")
sidebar.markdown("© 2024 Soumick Chatterjee | Glastonbury Group | Human Technopole")
|
|
|
|
|
# Main page: heading, file input, seed choice and the trigger button.
# Widget order matters: Streamlit renders elements in call order.
st.header("From single-slice cardiac long-axis dynamic CINE scan (3D: TxHxW) to 128 latent factors...")

uploaded_file = st.file_uploader(
    label="Please upload a 3D NIfTI file (.nii or .nii.gz)",
    type=["nii", "nii.gz"],
)

# Seeds offered in the selector; the user's choice is exposed as `selected_seed`.
seed_values = [1701, 1993, 1994, 42, 2023]
selected_seed = st.selectbox(label="Select a seed value:", options=seed_values)

# True only on the rerun triggered by clicking the button.
process_button = st.button(label="Process")
|
|
|
# Session-state slot holding the most recent result. Clicking a download button
# makes Streamlit rerun the whole script with `process_button` False; keeping the
# result here keeps both download buttons available after one of them is used.
_RESULT_STATE_KEY = "latent_result"


@st.cache_resource
def load_model(model_name):
    """Load a pre-trained DiffAE3D model from the Hugging Face Hub in eval mode.

    Cached per ``model_name`` by ``st.cache_resource``. Failures are raised rather
    than returned: a returned ``None`` would be cached and served forever, even
    after the token or network problem has been fixed.

    Args:
        model_name: Hugging Face Hub repository id of the model.

    Returns:
        The loaded model, already switched to eval mode.

    Raises:
        RuntimeError: if ``HF_API_TOKEN`` is unset or the model cannot be loaded.
    """
    hf_token = os.environ.get('HF_API_TOKEN')
    if hf_token is None:
        raise RuntimeError(
            "Hugging Face API token is not set. Please set the 'HF_API_TOKEN' environment variable."
        )
    try:
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            token=hf_token,  # replaces the deprecated `use_auth_token`
        )
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {e}") from e
    model.eval()
    return model


def _read_nifti_data(uploaded_file):
    """Return the voxel data (``get_fdata()``) of an uploaded NIfTI file.

    nibabel needs a filesystem path and picks plain vs gzip decoding from the
    extension, so the bytes are written to ``upload.nii`` / ``upload.nii.gz``
    inside a temporary directory. Unlike an open ``NamedTemporaryFile``, that
    file can be reopened by nibabel on every OS, and it is removed on exit.
    """
    suffix = ".nii.gz" if uploaded_file.name.lower().endswith(".gz") else ".nii"
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir) / f"upload{suffix}"
        tmp_path.write_bytes(uploaded_file.getvalue())
        # mmap=False: read fully into memory, because the file is deleted on exit.
        return nib.load(str(tmp_path), mmap=False).get_fdata()


def _encode(model, tensor):
    """Run a 3-D tensor through ``model.encode`` on CPU and return the latent tensor."""
    device = torch.device('cpu')
    model = model.to(device)
    # Add batch and channel axes in front of the 3 input dims.
    batch = tensor.unsqueeze(0).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model.encode(batch, use_ema=model.config.test_ema)
    # `encode` may return a tuple; its first element is taken as the latent.
    if isinstance(output, tuple):
        output = output[0]
    return output.squeeze(0)


def _build_result(output, result_key):
    """Package an encoder output for display and download.

    Returns a dict with the identifying ``key``, the printed ``shape``, and the
    ``json`` / ``csv`` payloads, both listing the flattened values under
    ``latent_factors``.
    """
    latents = output.detach().cpu().numpy().flatten().tolist()
    return {
        "key": result_key,
        "shape": str(output.shape),
        "json": json.dumps({"latent_factors": latents}, indent=4),
        "csv": pd.DataFrame({'latent_factors': latents}).to_csv(index=False),
    }


def _render_result(result):
    """Show the completion message, output shape and both download buttons."""
    st.success("Processing complete.")
    st.write(f"Output tensor shape: `{result['shape']}`")
    st.download_button(
        label="Download Output as a JSON File",
        data=result["json"],
        file_name='latent_factors.json',
        mime='application/json',
    )
    st.download_button(
        label="Download Output as a CSV File",
        data=result["csv"],
        file_name='latent_factors.csv',
        mime='text/csv',
    )


def _process_upload(uploaded_file, seed, result_key):
    """Read, validate and encode the upload; return a result dict or ``None``.

    ``None`` means the problem was already reported to the user via ``st.error``
    (the volume is not 3-D, or the model could not be loaded).
    """
    data = _read_nifti_data(uploaded_file)
    tensor = torch.from_numpy(data).float()
    if tensor.ndim != 3:
        st.error("The uploaded NIfTI file is not a 3D volume. Please upload a valid 3D NIfTI file.")
        return None

    st.success("File successfully uploaded and read.")
    st.write(f"Input tensor shape: `{tensor.shape}`")
    st.write(f"Selected seed value: `{seed}`")

    model_name = f"GlastonburyGroup/UKBBLatent_Cardiac_20208_DiffAE3D_L128_S{seed}"
    with st.spinner('Loading the pre-trained model...'):
        try:
            model = load_model(model_name)
        except RuntimeError as e:
            st.error(str(e))
            return None

    with st.spinner('Processing the tensor through the model...'):
        output = _encode(model, tensor)
    return _build_result(output, result_key)


if uploaded_file is None:
    st.info("Awaiting file upload...")
else:
    # Identifies this exact upload content + seed, so a stored result is never
    # shown for a different file or seed.
    result_key = (uploaded_file.name, hash(uploaded_file.getvalue()), selected_seed)
    if process_button:
        st.session_state.pop(_RESULT_STATE_KEY, None)
        try:
            result = _process_upload(uploaded_file, selected_seed, result_key)
            if result is not None:
                st.session_state[_RESULT_STATE_KEY] = result
                _render_result(result)
        except Exception as e:
            # Top-level UI boundary: report any unexpected failure to the user.
            st.error(f"An error occurred: {e}")
    else:
        previous = st.session_state.get(_RESULT_STATE_KEY)
        if previous is not None and previous["key"] == result_key:
            # Rerun caused by a download click: re-show the stored result.
            _render_result(previous)
        else:
            st.info("Click the 'Process' button to start processing.")