Spaces:
Runtime error
Runtime error
import streamlit as st | |
import streamlit.components.v1 as components | |
import matplotlib.pyplot as plt | |
import pyvista as pv | |
import torch | |
import requests | |
import numpy as np | |
import numpy.typing as npt | |
from dcgan import DCGAN3D_G | |
import os | |
import pathlib | |
# Start a virtual X framebuffer so PyVista can render without a physical
# display (presumably the hosting container has none).
pv.start_xvfb()

# The model checkpoint and the exported HTML scenes are written under
# Streamlit's own static directory.
STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / 'static'
DOWNLOADS_PATH = STREAMLIT_STATIC_PATH / "downloads"
# exist_ok removes the check-then-create race between concurrent sessions or
# reruns; parents=True also covers a missing ``static`` directory. If the path
# exists but is not a directory, FileExistsError is still raised.
DOWNLOADS_PATH.mkdir(parents=True, exist_ok=True)
def download_checkpoint(url: str, path: str, timeout: float = 60.0) -> None:
    """Download ``url`` to ``path``; the file only appears there on success.

    The response body is streamed into a sibling ``<name>.part`` file and
    moved into place with ``os.replace`` once the whole body has been
    written. ``main`` skips the download whenever ``path`` already exists, so
    a truncated file or an HTTP error page must never be left at ``path``.
    Otherwise it would be cached and handed to ``torch.load`` on every run.

    Args:
        url: Location of the file to fetch.
        path: Destination file (``str`` or ``pathlib.Path``).
        timeout: Seconds ``requests`` waits for the connection and between
            received bytes before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    target = pathlib.Path(path)
    partial = target.with_name(target.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=timeout) as resp:
            # Without this, an error page would be written out as if it
            # were the checkpoint.
            resp.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in resp.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Atomic rename: readers never observe a half-written checkpoint.
        os.replace(partial, target)
    finally:
        # After a successful replace this is a no-op. On failure it removes
        # the incomplete temporary file.
        if partial.exists():
            partial.unlink()
def generate_image(path: str,
                   image_size: int = 64,
                   z_dim: int = 512,
                   n_channels: int = 1,
                   n_features: int = 32,
                   ngpu: int = 1,
                   latent_size: int = 3) -> np.ndarray:
    """Sample one volumetric image from the pretrained 3D DCGAN generator.

    Args:
        path: Generator ``state_dict`` checkpoint (``str`` or ``pathlib.Path``).
        image_size, z_dim, n_channels, n_features, ngpu: Passed positionally
            to ``DCGAN3D_G``. They must match the architecture the checkpoint
            was saved from.
        latent_size: Extent of the latent noise tensor along each of its three
            spatial axes. Larger values give larger output volumes, since the
            generator is fully convolutional.

    Returns:
        Channel 0 of the single generated sample, transformed as
        ``1 - (x + 1) / 2``. Generator values in [-1, 1] would map to [0, 1]
        with inverted polarity.
    """
    netG = DCGAN3D_G(image_size, z_dim, n_channels, n_features, ngpu)
    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code, and this checkpoint is downloaded from the network.
    # Only load trusted files. On torch >= 1.13, ``weights_only=True`` is
    # sufficient here because only a state_dict is needed.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    netG.load_state_dict(state_dict)
    # Noise of shape (batch=1, z_dim, latent_size, latent_size, latent_size).
    z = torch.randn(1, z_dim, latent_size, latent_size, latent_size)
    with torch.no_grad():
        X = netG(z)
    # Keep sample 0, channel 0, then rescale and invert.
    img = 1 - (X[0, 0].numpy() + 1) / 2
    return img
def create_uniform_mesh_marching_cubes(img: npt.ArrayLike):
    """Build orthogonal slices and a marching-cubes isosurface of a 3D volume.

    Args:
        img: 3D array of voxel values. Each voxel becomes one grid point, with
            unit spacing and the grid origin at (0, 0, 0).

    Returns:
        ``(slices, mesh, dist)``:
        - ``slices``: three orthogonal slices through the uniform grid.
        - ``mesh``: the marching-cubes contour.
        - ``dist``: the Euclidean distance of every mesh vertex from the origin.
    """
    volume = np.asarray(img)
    grid = pv.UniformGrid(
        dims=volume.shape,
        spacing=(1, 1, 1),
        origin=(0, 0, 0),
    )
    # VTK image data stores points with x varying fastest. Flattening in
    # Fortran order makes grid point (i, j, k) receive volume[i, j, k]. A
    # C-order flatten would transpose the volume and scramble non-cubic
    # inputs.
    values = volume.flatten(order="F")
    grid.point_data['my_array'] = values
    slices = grid.slice_orthogonal()
    # NOTE(review): one isosurface with rng=[1, 0]. Confirm which iso-value
    # this yields; VTK's GenerateValues uses the range start for a single
    # contour.
    mesh = grid.contour(1, values, method='marching_cubes', rng=[1, 0], preference="points")
    dist = np.linalg.norm(mesh.points, axis=1)
    return slices, mesh, dist
def create_matplotlib_figure(img: npt.ArrayLike, midpoint: int):
    """Plot three axis-aligned cross-sections of a 3D volume side by side.

    Each panel shows the plane at index ``midpoint`` along one axis:
    axis 0 ("Front"), axis 1 ("Right") and the last axis ("Top"). Images are
    drawn in grayscale on a fixed [0, 1] intensity scale, with axes hidden.

    Returns:
        The matplotlib figure holding the three panels.
    """
    fig, panels = plt.subplots(1, 3, figsize=(18, 6))
    sections = (
        ("Front", img[midpoint]),
        ("Right", img[:, midpoint]),
        ("Top", img[..., midpoint]),
    )
    for panel, (label, plane) in zip(panels, sections):
        panel.imshow(plane, cmap="gray", vmin=0, vmax=1)
        panel.set_title(label, fontsize=18)
        panel.set_axis_off()
    return fig
def _export_scene_html(dataset, html_path, window_size, **mesh_kwargs) -> None:
    """Add ``dataset`` to a new single-view PyVista plotter and export it as HTML.

    ``mesh_kwargs`` are forwarded to ``Plotter.add_mesh``. The plotter is
    closed afterwards, even if the export fails, so plotters do not pile up
    across Streamlit reruns.
    """
    plotter = pv.Plotter(shape=(1, 1), window_size=window_size)
    try:
        plotter.add_mesh(dataset, **mesh_kwargs)
        plotter.export_html(html_path)
    finally:
        plotter.close()


def _embed_html(html_path, header: str, width: int, height: int) -> None:
    """Show ``header`` and embed the HTML scene stored at ``html_path``."""
    # read_text opens and closes the file itself, so no handle stays open.
    source_code = html_path.read_text(encoding='utf-8')
    st.header(header)
    components.html(source_code, width=width, height=height)
    st.markdown("_Click and drag to spin, right click to shift._")


def main():
    """Render the PorousMediaGAN demo page.

    The page:
    - downloads the generator checkpoint if it is not already on disk;
    - samples a volume whose size is set by the latent-size slider;
    - shows 2D cross-sections of that volume;
    - embeds two interactive 3D scenes exported from PyVista as HTML.
    """
    st.title("Generating Porous Media with GANs")
    st.markdown(
        """
### Author
_[Lukas Mosser](https://scholar.google.com/citations?user=y0R9snMAAAAJ&hl=en&oi=ao) (2022)_ - :bird:[porestar](https://twitter.com/porestar)
## Description
This is a demo of the Generative Adversarial Network (GAN, [Goodfellow 2014](https://arxiv.org/abs/1406.2661)) trained for our publication [PorousMediaGAN](https://github.com/LukasMosser/PorousMediaGan)
published in Physical Review E ([Mosser et. al 2017](https://journals.aps.org/pre/abstract/10.1103/PhysRevE.96.043309))
The model is a pretrained 3D Deep Convolutional GAN ([Radford 2015](https://arxiv.org/abs/1511.06434)) that generates a volumetric image of a porous medium, here a Berea sandstone, from a set of pretrained weights.
## Intent
I hope this encourages others to create interactive demos of their research for knowledge sharing and validation.
## The Demo
Slices through the 3D volume are rendered using [PyVista](https://www.pyvista.org/) and [PyThreeJS](https://pythreejs.readthedocs.io/en/stable/)
The model itself currently runs on the :hugging_face: [Huggingface Spaces](https://huggingface.co/spaces) instance.
Future migration to the :hugging_face: [Huggingface Models](https://huggingface.co/models) repository is possible.
### Interactive Model Parameters
The GAN used here in this study is fully convolutional "_Look Ma' no MLP's_": Changing the spatial extent of the latent space vector _z_
allows one to generate larger synthetic images.
"""
        , unsafe_allow_html=True)

    view_width = 400
    view_height = 400
    window_size = (view_width, view_height)

    model_fname = "berea_generator_epoch_24.pth"
    model_path = DOWNLOADS_PATH / model_fname
    checkpoint_url = "https://github.com/LukasMosser/PorousMediaGan/blob/master/checkpoints/berea/{0:}?raw=true".format(model_fname)
    # The checkpoint is cached on disk; download only on first use.
    if not model_path.exists():
        download_checkpoint(checkpoint_url, model_path)

    latent_size = st.slider("Latent Space Size z", min_value=1, max_value=5, step=1)
    img = generate_image(model_path, latent_size=latent_size)
    slices, mesh, dist = create_uniform_mesh_marching_cubes(img)

    pv.set_plot_theme("document")
    slices_html = DOWNLOADS_PATH / 'slices.html'
    mesh_html = DOWNLOADS_PATH / 'mesh.html'
    _export_scene_html(slices, slices_html, window_size, cmap="gray")
    # The pore-space mesh is coloured by each vertex's distance from the origin.
    _export_scene_html(mesh, mesh_html, window_size, scalars=dist)

    st.header("2D Cross-Section of Generated Volume")
    fig = create_matplotlib_figure(img, img.shape[0]//2)
    st.pyplot(fig=fig)
    # pyplot keeps a reference to every figure it creates; close it so
    # figures do not accumulate on every rerun.
    plt.close(fig)

    _embed_html(slices_html, "3D Intersections", view_width, view_height)
    _embed_html(mesh_html, "3D Pore Space Mesh", view_width, view_height)

    st.markdown("""
## Citation
If you use our code for your own research, we would be grateful if you cite our publication:
```
@article{pmgan2017,
title={Reconstruction of three-dimensional porous media using generative adversarial neural networks},
author={Mosser, Lukas and Dubrule, Olivier and Blunt, Martin J.},
journal={arXiv preprint arXiv:1704.03225},
year={2017}
}```
""")
# Entry point: build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()