Prism / app.py
Masaaki Kawata
Update app.py
edb8e55
import json
import numpy as np
import spaces
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from typing import Iterable
from tempfile import NamedTemporaryFile
from PIL import Image
from nudenet import NudeDetector
from parallax import generate_parallax_images, generate_animation_images
class Theme(Base):
    """Neutral Gradio theme with a cyan accent colour and a pink loader.

    Every hue, size and font argument is handed straight to
    ``gradio.themes.base.Base``. Afterwards a fixed set of theme variables
    (accent, slider, primary button, loader) is overridden with the same
    values in light and dark mode.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.neutral,
        secondary_hue: colors.Color | str = colors.neutral,
        neutral_hue: colors.Color | str = colors.neutral,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_md,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont('Barlow'),
            'ui-sans-serif',
            'sans-serif',
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont('IBM Plex Mono'),
            'ui-monospace',
            'monospace',
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

        # Theme-variable overrides, applied identically to light and dark mode.
        overrides = {
            'color_accent': 'rgb(0 231 255 / 1)',
            'slider_color': 'rgb(0 231 255 / 1)',
            'slider_color_dark': 'rgb(0 231 255 / 1)',
            'button_primary_background_fill': 'rgb(0 231 255 / 1)',
            'button_primary_background_fill_hover': 'rgb(0 231 255 / .75)',
            'button_primary_text_color': '#ffffff',
            'button_primary_background_fill_dark': 'rgb(0 231 255 / 1)',
            'button_primary_background_fill_hover_dark': 'rgb(0 231 255 / .75)',
            'button_primary_text_color_dark': '#ffffff',
            'loader_color': 'rgb(255 199 229 / 1)',
            'loader_color_dark': 'rgb(255 199 229 / 1)',
        }
        super().set(**overrides)
@spaces.GPU
def generate_parallax(image: np.ndarray, n: int):
    """Split *image* into parallax layers and flag likely-NSFW content.

    Args:
        image: Input image delivered by the ``gr.Image`` component as a NumPy
            array (``None`` when nothing was uploaded). Converted to RGBA.
        n: Requested number of layers, forwarded to
            ``generate_parallax_images``.

    Returns:
        A three-element list ``[layers, indexes_json, nsfw]``:
        the non-``None`` items yielded by ``generate_parallax_images``; a JSON
        string listing their positions in the generator's output; and
        ``True`` when any NudeNet detection scores at least 0.75.

    Raises:
        gr.Error: If no input image was supplied.
    """
    import os  # local import: only needed to clean up the temporary file

    if image is None:
        # Without this guard Image.fromarray(None) fails with an unhelpful error.
        raise gr.Error('Please provide an input image.')
    input_image = Image.fromarray(image).convert('RGBA')

    # NudeDetector.detect is called with a file path, so the image is written
    # to a temporary WEBP file first. delete=False (presumably so the file can
    # be reopened by name on platforms that forbid a second open handle) means
    # the file must be removed explicitly; otherwise one file leaks per request.
    temp_file = NamedTemporaryFile(delete=False, suffix='.webp')
    try:
        with temp_file:
            input_image.save(temp_file, format='WEBP')
        # The file is closed before detection so the detector sees complete data.
        detector = NudeDetector()
        nsfw = any(
            detection['score'] >= 0.75
            for detection in detector.detect(temp_file.name)
        )
    finally:
        os.remove(temp_file.name)

    layers = []
    indexes = []
    for index, layer in enumerate(generate_parallax_images(input_image, n)):
        # Keep only produced layers, remembering each one's original position.
        if layer is not None:
            layers.append(layer)
            indexes.append(index)
    return [layers, json.dumps(indexes), nsfw]
@spaces.GPU
def generate_animation(image: np.ndarray):
    """Run ``generate_animation_images`` on *image* converted to RGBA.

    Args:
        image: Input image delivered by the ``gr.Image`` component as a NumPy
            array (``None`` when nothing was uploaded).

    Returns:
        The result of ``generate_animation_images`` unchanged. It is wired to
        a ``gr.Gallery`` output, so presumably a sequence of images — TODO
        confirm.

    Raises:
        gr.Error: If no input image was supplied.
    """
    if image is None:
        # Without this guard Image.fromarray(None) fails with an unhelpful error.
        raise gr.Error('Please provide an input image.')
    return generate_animation_images(Image.fromarray(image).convert('RGBA'))
with gr.Blocks(theme=Theme()) as demo:
    # Parallax section: image + layer count in; layer gallery, kept-layer
    # indexes and an NSFW flag out.
    gr.Markdown('**Parallax**')
    with gr.Row():
        with gr.Column():
            parallax_input_image = gr.Image(image_mode='RGBA', label='Input')
            parallax_input_layers = gr.Number(
                value=5,
                precision=0,
                minimum=2,
                maximum=10,
                step=1,
                label='Layers',
            )
            parallax_generate_button = gr.Button('Generate')
        with gr.Column():
            parallax_output_gallery = gr.Gallery(label='Outputs', columns=5)
            parallax_output_json = gr.JSON(label='Indexes')
            parallax_output_nsfw = gr.Checkbox(label='NSFW')

    # Animation section: image in, gallery of generated frames out.
    gr.Markdown('**Animation**')
    with gr.Row():
        with gr.Column():
            animation_input_image = gr.Image(image_mode='RGBA', label='Input')
            animation_generate_button = gr.Button('Generate')
        with gr.Column():
            animation_output_gallery = gr.Gallery(label='Outputs', columns=5)

    # Listeners must be registered inside the Blocks context. Registration
    # order is kept stable because it determines each endpoint's fn_index.
    parallax_generate_button.click(
        fn=generate_parallax,
        inputs=[parallax_input_image, parallax_input_layers],
        outputs=[
            parallax_output_gallery,
            parallax_output_json,
            parallax_output_nsfw,
        ],
        api_name='generate_parallax',
    )
    animation_generate_button.click(
        fn=generate_animation,
        inputs=animation_input_image,
        outputs=animation_output_gallery,
        api_name='generate_animation',
    )

if __name__ == '__main__':
    demo.launch()