import streamlit as st
import streamlit.components.v1 as components
import matplotlib.pyplot as plt
import pyvista as pv
import torch
import requests
import numpy as np
import numpy.typing as npt
from dcgan import DCGAN3D_G
import pathlib
import time

pv.start_xvfb()


class DummyWriteable:
    """Minimal writable sink that keeps the last string written to it.

    An instance is passed to ``pyvista.Plotter.export_html`` in place of a
    filename so the exported HTML stays in memory and can be embedded with
    ``streamlit.components.v1.html``.
    """

    def __init__(self):
        # Most recent payload given to ``write``; ``None`` until written.
        self.html = None

    def write(self, html):
        # Overwrites (does not append): only the last write is kept.
        self.html = html


STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / "static"
DOWNLOADS_PATH = STREAMLIT_STATIC_PATH / "downloads"
# exist_ok avoids a crash when two sessions/processes create the folder at
# the same time; parents covers installs where ``static`` is missing.
DOWNLOADS_PATH.mkdir(parents=True, exist_ok=True)


_INTRO_MARKDOWN = """
### Author
_[Lukas Mosser](https://scholar.google.com/citations?user=y0R9snMAAAAJ&hl=en&oi=ao) (2022)_ - :bird:[porestar](https://twitter.com/porestar)

## Description
This is a demo of the Generative Adversarial Network (GAN, [Goodfellow 2014](https://arxiv.org/abs/1406.2661)) trained for our publication [PorousMediaGAN](https://github.com/LukasMosser/PorousMediaGan)
published in Physical Review E ([Mosser et al. 2017](https://journals.aps.org/pre/abstract/10.1103/PhysRevE.96.043309))

The model is a pretrained 3D Deep Convolutional GAN ([Radford 2015](https://arxiv.org/abs/1511.06434)) that generates a volumetric image of a porous medium, here a Berea sandstone, from a set of pretrained weights.

## Intent
I hope this encourages others to create interactive demos of their research for knowledge sharing and validation.

## The Demo
Slices through the 3D volume are rendered using [PyVista](https://www.pyvista.org/) and [PyThreeJS](https://pythreejs.readthedocs.io/en/stable/)

The model itself currently runs on the :hugging_face: [Huggingface Spaces](https://huggingface.co/spaces) instance.
Future migration to the :hugging_face: [Huggingface Models](https://huggingface.co/models) repository is possible.

### Interactive Model Parameters
The GAN used here in this study is fully convolutional "_Look Ma' no MLP's_": Changing the spatial extent of the latent space vector _z_
allows one to generate larger synthetic images.
"""

_CITATION_MARKDOWN = """
## Citation
If you use our code for your own research, we would be grateful if you cite our publication:
```
@article{pmgan2017,
    title={Reconstruction of three-dimensional porous media using generative adversarial neural networks},
    author={Mosser, Lukas and Dubrule, Olivier and Blunt, Martin J.},
    journal={arXiv preprint arXiv:1704.03225},
    year={2017}
}
```
"""


def download_checkpoint(url: str, path: "str | pathlib.Path", timeout: float = 60.0) -> None:
    """Download ``url`` and store the response body at ``path``.

    The body is first written to a sibling ``.part`` file and renamed into
    place only after a successful download. ``main`` downloads only when
    ``path`` does not exist yet, so a truncated file or saved error page at
    ``path`` would otherwise be reused on every run.

    Args:
        url: Location of the checkpoint file.
        path: Destination file.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection errors or timeouts.
    """
    path = pathlib.Path(path)
    resp = requests.get(url, timeout=timeout)
    # Without this, e.g. a 404 HTML page would be saved as model weights.
    resp.raise_for_status()
    tmp_path = path.with_name(path.name + ".part")
    tmp_path.write_bytes(resp.content)
    tmp_path.replace(path)


@st.cache(persist=True, allow_output_mutation=True)
def load_model(path: str,
               image_size: int = 64,
               z_dim: int = 512,
               n_channels: int = 1,
               n_features: int = 32,
               ngpu: int = 1,) -> torch.nn.Module:
    """Build a ``DCGAN3D_G`` generator and load the weights stored at ``path``.

    The weights are loaded onto the CPU (``map_location``). The constructor
    arguments must match the architecture the checkpoint was saved from.
    NOTE(review): the meaning of ``ngpu`` is defined in ``dcgan``; confirm there.
    """
    netG = DCGAN3D_G(image_size, z_dim, n_channels, n_features, ngpu)
    netG.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    return netG


@st.cache()
def generate_image(netG: torch.nn.Module, z_dim: int = 512, latent_size: int = 3) -> np.ndarray:
    """Sample one volume from ``netG`` using a Gaussian latent of size ``latent_size**3``.

    Because of ``st.cache`` the same arguments return the cached volume rather
    than a new random sample.

    Returns:
        The first channel of the first sample, mapped as ``1 - (x + 1) / 2``.
        Presumably the generator output lies in [-1, 1]; this maps it to
        [0, 1] and inverts it (TODO confirm the generator's output activation).
    """
    z = torch.randn(1, z_dim, latent_size, latent_size, latent_size)
    with torch.no_grad():
        X = netG(z)
    img = 1 - (X[0, 0].numpy() + 1) / 2
    return img


def create_uniform_mesh_marching_cubes(img: npt.ArrayLike):
    """Wrap the 3D array ``img`` in a uniform PyVista grid and derive the 3D views.

    Returns:
        slices: Three orthogonal slices through the grid.
        mesh: Isosurface extracted with the marching-cubes contour method.
        dist: Distance of every mesh vertex from the origin, used to colour the mesh.
    """
    img = np.asarray(img)
    grid = pv.UniformGrid(
        dims=img.shape,
        spacing=(1, 1, 1),
        origin=(0, 0, 0),
    )
    # VTK point data is ordered with the first (x) index varying fastest
    # (Fortran order). A C-order flatten transposes the axes relative to
    # ``dims`` and scrambles the data for non-cubic volumes.
    values = img.flatten(order="F")
    grid.point_data['my_array'] = values
    slices = grid.slice_orthogonal()
    # NOTE(review): a single isosurface over the reversed range [1, 0]; which
    # iso-value VTK picks from this range should be confirmed.
    mesh = grid.contour(1, values, method='marching_cubes', rng=[1, 0], preference="points")
    dist = np.linalg.norm(mesh.points, axis=1)
    return slices, mesh, dist


def create_matplotlib_figure(img: npt.ArrayLike, midpoint: int):
    """Plot the three axis-aligned cross-sections of ``img`` at index ``midpoint``.

    The panels are titled "Front", "Right" and "Top" and use a fixed grey
    scale over [0, 1]. The caller owns the returned figure and should close it.
    """
    fig, ax = plt.subplots(1, 3, figsize=(18, 6))
    views = (img[midpoint], img[:, midpoint], img[..., midpoint])
    for a, view, title in zip(ax, views, ("Front", "Right", "Top")):
        a.imshow(view, cmap="gray", vmin=0, vmax=1)
        a.set_title(title, fontsize=18)
        a.set_axis_off()
    return fig


def _export_plot_html(data, width: int, height: int, **mesh_kwargs) -> "str | None":
    """Render ``data`` in a PyVista plotter and return it as an HTML string.

    Returns ``None`` if ``export_html`` raises ``RuntimeError``; the error is
    printed. The plotter is always closed so its resources are released.
    """
    plotter = pv.Plotter(shape=(1, 1), window_size=(width, height))
    try:
        plotter.add_mesh(data, **mesh_kwargs)
        sink = DummyWriteable()
        try:
            plotter.export_html(sink)
        except RuntimeError as e:
            print(e)
            return None
        return sink.html
    finally:
        plotter.close()


def _show_plot_html(html: "str | None", width: int, height: int) -> None:
    """Embed exported plotter HTML in the page, or show a warning if it is missing."""
    if html is None:
        st.warning("The interactive 3D view could not be rendered.")
        return
    components.html(html, width=width, height=height)
    st.markdown("_Click and drag to spin, right click to shift._")


def main():
    """Build the Streamlit page: intro, latent-size slider, 2D and 3D views, and citation."""
    st.title("Generating Porous Media with GANs")
    st.markdown(_INTRO_MARKDOWN, unsafe_allow_html=True)

    view_width = 400
    view_height = 400

    model_fname = "berea_generator_epoch_24.pth"
    checkpoint_url = (
        "https://github.com/LukasMosser/PorousMediaGan/blob/master/"
        f"checkpoints/berea/{model_fname}?raw=true"
    )
    checkpoint_path = DOWNLOADS_PATH / model_fname

    if not checkpoint_path.exists():
        try:
            download_checkpoint(checkpoint_url, checkpoint_path)
        except requests.RequestException as err:
            st.error(f"Could not download the model checkpoint: {err}")
            st.stop()

    netG = load_model(checkpoint_path)

    latent_size = st.slider("Latent Space Size z", min_value=1, max_value=5, step=1)

    img = generate_image(netG, latent_size=latent_size)
    slices, mesh, dist = create_uniform_mesh_marching_cubes(img)

    pv.set_plot_theme("document")
    slices_html = _export_plot_html(slices, view_width, view_height, cmap="gray")
    mesh_html = _export_plot_html(mesh, view_width, view_height, scalars=dist)

    st.header("2D Cross-Section of Generated Volume")
    fig = create_matplotlib_figure(img, img.shape[0] // 2)
    st.pyplot(fig=fig)
    # st.pyplot does not clear a figure passed explicitly; close it so each
    # rerun does not leak a figure into pyplot's global registry.
    plt.close(fig)

    st.header("3D Intersections")
    _show_plot_html(slices_html, view_width, view_height)

    st.header("3D Pore Space Mesh")
    _show_plot_html(mesh_html, view_width, view_height)

    st.markdown(_CITATION_MARKDOWN)


if __name__ == "__main__":
    main()