ae_gen / app.py
mehdidc's picture
add ref to README in the app
8ae3c08
raw
history blame contribute delete
No virus
2.66 kB
import math
import torch
import torchvision
import gradio as gr
from PIL import Image
from cli import iterative_refinement
from viz import grid_of_images_default
# Pretrained autoencoders, loaded once at import time onto the CPU and keyed by
# the display name used in the UI dropdown (dict insertion order is the
# dropdown order).
# NOTE(review): these checkpoints appear to be whole pickled model objects
# (gen() later sets attributes such as `nb_active` on them). torch>=2.6 defaults
# `torch.load` to `weights_only=True`, which may refuse such files — confirm the
# torch version pinned for this Space.
models = {
    name: torch.load(path, map_location="cpu")
    for name, path in (
        ("ConvAE", "convae.th"),
        ("Deep ConvAE", "deep_convae.th"),
        ("Dense K-Sparse", "fc_sparse.th"),
    )
}
def _last_iteration_grid(last, fill):
    """Lay out final-iteration samples on a near-square grid that shows all of them.

    Args:
        last: Tensor indexed by sample along its first axis.
        fill: Intensity used for blank padding tiles.

    Returns:
        The array produced by ``grid_of_images_default``.
    """
    nb = last.shape[0]
    # Smallest integer number of columns whose square covers all samples
    # (exact integer arithmetic, no float sqrt rounding).
    nb_cols = max(1, math.isqrt(nb))
    if nb_cols * nb_cols < nb:
        nb_cols += 1
    nb_rows = max(1, -(-nb // nb_cols))  # ceil division
    nb_pad = nb_rows * nb_cols - nb
    if nb_pad:
        # Pad with blank tiles instead of silently dropping samples when the
        # count is not a perfect square. Perfect squares get no padding, so
        # the layout is then identical to a plain (s, s) grid.
        padding = torch.full((nb_pad, *last.shape[1:]), fill, dtype=last.dtype)
        last = torch.cat([last, padding])
    return grid_of_images_default(last.numpy(), shape=(nb_rows, nb_cols))


def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, only_last, black_bg, binarize, binarize_threshold):
    """Generate samples with a pretrained autoencoder and return them as one image.

    This is the Gradio callback; its positional parameters must match the
    order of the interface's ``inputs`` list.

    Args:
        md: Value of the leading ``gr.Markdown`` input. Unused; it only fills
            the first positional slot.
        model_name: Key into the module-level ``models`` dict.
        seed: Passed to ``torch.manual_seed`` so results are reproducible.
        nb_iter: Number of refinement iterations for ``iterative_refinement``.
        nb_samples: Number of samples to generate.
        width, height: Spatial size of each sample.
        nb_active: Only used for "Dense K-Sparse"; assigned to the model's
            ``nb_active`` attribute.
        only_last: If true, show only the last iteration on a near-square grid.
            Otherwise show an (iterations x samples) grid.
        black_bg: If false, intensities are inverted (``1 - samples``).
        binarize: Whether ``binarize_threshold`` is forwarded.
        binarize_threshold: Threshold forwarded to ``iterative_refinement``
            when ``binarize`` is true (``None`` otherwise).

    Returns:
        A ``PIL.Image`` holding the grid as uint8 pixels.
    """
    torch.manual_seed(int(seed))
    nb_samples = int(nb_samples)
    height, width = int(height), int(width)
    model = models[model_name]
    if model_name == "Dense K-Sparse":
        # This mutates the shared module-level model; it is re-assigned on
        # every call. Sliders may yield floats, but this is a count -> int.
        model.nb_active = int(nb_active)
    samples = iterative_refinement(
        model,
        nb_iter=int(nb_iter),
        nb_examples=nb_samples,
        w=width, h=height, c=1,
        batch_size=64,
        binarize_threshold=binarize_threshold if binarize else None,
    )
    if not black_bg:
        samples = 1 - samples
    if only_last:
        # Assumes background pixels are 0 before inversion (MNIST-style
        # white-on-black), hence 1 after it — TODO confirm.
        grid = _last_iteration_grid(samples[-1], fill=0.0 if black_bg else 1.0)
    else:
        # samples[-1] is the last iteration, so axis 0 indexes iterations and
        # axis 1 indexes samples; flatten both into one list of tiles.
        nb_rows, nb_cols = samples.shape[0], samples.shape[1]
        tiles = samples.reshape((nb_rows * nb_cols, height, width, 1))
        grid = grid_of_images_default(tiles.numpy(), shape=(nb_rows, nb_cols))
    grid = (grid * 255).astype("uint8")
    return Image.fromarray(grid)
text = """
This interface supports generation of samples from:
- ConvAE model (from [`Digits that are not: Generating new types through deep neural nets`](https://arxiv.org/pdf/1606.04345.pdf))
- DeepConvAE model (from [here](https://tel.archives-ouvertes.fr/tel-01838272/file/75406_CHERTI_2018_diffusion.pdf), Section 10.1 with `L=3`)
- Dense K-Sparse model (from [`Out-of-class novelty generation`](https://openreview.net/forum?id=r1QXQkSYg))
These models were trained on MNIST only (digits), but were found to generate new kinds of symbols, see the references for more details.
NB: `nb_active` is only used for the Dense K-Sparse, specifying nb of activations to keep in the last layer.
Check <https://huggingface.co/spaces/mehdidc/ae_gen/blob/main/README.md> for
more details.
"""
# Gradio UI. The order of `inputs` must match gen's positional parameters:
# the leading gr.Markdown only fills gen's unused `md` slot. Integer-valued
# fields use precision=0 so Gradio rounds them and delivers ints.
iface = gr.Interface(
    fn=gen,
    inputs=[
        gr.Markdown(text),
        gr.Dropdown(list(models.keys()), value="Deep ConvAE", label="Model"),
        gr.Number(value=0, precision=0, label="Seed"),
        gr.Number(value=25, precision=0, label="Iterations"),
        gr.Number(value=1, precision=0, label="Number of samples"),
        gr.Number(value=28, precision=0, label="Width"),
        gr.Number(value=28, precision=0, label="Height"),
        gr.Slider(minimum=0, maximum=800, value=800, step=1, label="nb_active (Dense K-Sparse only)"),
        gr.Checkbox(value=False, label="Only show last iteration"),
        gr.Checkbox(value=True, label="Black background"),
        gr.Checkbox(value=False, label="binarize"),
        gr.Number(value=0.5, label="Binarize threshold"),
    ],
    outputs="image",
)

# Only start the server when run as a script (e.g. `python app.py`), not when
# this module is imported.
if __name__ == "__main__":
    iface.launch()