|
import json
import os
from tempfile import NamedTemporaryFile
from typing import Iterable

import numpy as np
import spaces
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from PIL import Image
from nudenet import NudeDetector

from parallax import generate_parallax_images, generate_animation_images
|
|
|
|
|
class Theme(Base):
    """App theme built on gradio's ``Base`` theme.

    Uses neutral hues, Barlow / IBM Plex Mono fonts, a cyan accent for
    sliders and primary buttons (light and dark mode), and a pink loader.
    All constructor arguments are keyword-only and forwarded to ``Base``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.neutral,
        secondary_hue: colors.Color | str = colors.neutral,
        neutral_hue: colors.Color | str = colors.neutral,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_md,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (fonts.GoogleFont('Barlow'), 'ui-sans-serif', 'sans-serif'),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (fonts.GoogleFont('IBM Plex Mono'), 'ui-monospace', 'monospace',),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

        # Palette shared by the light and dark variants.
        cyan = 'rgb(0 231 255 / 1)'
        cyan_hover = 'rgb(0 231 255 / .75)'
        white = '#ffffff'
        pink = 'rgb(255 199 229 / 1)'

        overrides = {
            'color_accent': cyan,
            'slider_color': cyan,
            'slider_color_dark': cyan,
            'button_primary_background_fill': cyan,
            'button_primary_background_fill_hover': cyan_hover,
            'button_primary_text_color': white,
            'button_primary_background_fill_dark': cyan,
            'button_primary_background_fill_hover_dark': cyan_hover,
            'button_primary_text_color_dark': white,
            'loader_color': pink,
            'loader_color_dark': pink,
        }
        super().set(**overrides)
|
|
|
|
|
def _detect_nsfw(image: Image.Image, threshold: float = 0.75) -> bool:
    """Return True if NudeDetector reports any detection scoring >= ``threshold``.

    ``NudeDetector.detect`` is called with a file path here, so the image is
    written to a temporary WEBP file. The file is always removed afterwards;
    the previous ``delete=False`` usage leaked one file per request.
    """
    file = NamedTemporaryFile(delete=False, suffix='.webp')
    try:
        # Close (and therefore flush) the file before the detector reopens it
        # by name; reopening an open NamedTemporaryFile fails on Windows.
        with file:
            image.save(file, format='WEBP')
        detections = NudeDetector().detect(file.name)
    finally:
        os.unlink(file.name)
    return any(detection['score'] >= threshold for detection in detections)


@spaces.GPU
def generate_parallax(image: np.ndarray, n: int):
    """Split ``image`` into parallax layers and flag likely NSFW content.

    Args:
        image: Input image array from the ``gr.Image`` input, converted to RGBA
            here. ``None`` when no image was provided.
        n: Requested number of layers, passed through to
            ``generate_parallax_images``.

    Returns:
        ``[layers, indexes_json, nsfw]``, where:
        - ``layers`` holds the non-``None`` layers produced by
          ``generate_parallax_images``.
        - ``indexes_json`` is a JSON string of each kept layer's position in
          that generator's output.
        - ``nsfw`` is True if any NudeDetector detection scored >= 0.75.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        raise gr.Error('Please provide an input image.')

    input_image = Image.fromarray(image).convert('RGBA')
    nsfw = _detect_nsfw(input_image)

    layers = []
    indexes = []
    # generate_parallax_images may yield None entries; skip them but keep
    # their original positions so callers can tell which layers are missing.
    for index, layer in enumerate(generate_parallax_images(input_image, n)):
        if layer is not None:
            layers.append(layer)
            indexes.append(index)

    return [layers, json.dumps(indexes), nsfw]
|
|
|
|
|
@spaces.GPU
def generate_animation(image: np.ndarray):
    """Run ``generate_animation_images`` on ``image`` converted to RGBA.

    Args:
        image: Input image array from the ``gr.Image`` input. ``None`` when no
            image was provided.

    Returns:
        Whatever ``generate_animation_images`` returns. The UI sends it to a
        ``gr.Gallery``, so presumably a sequence of images; verify against
        ``parallax.generate_animation_images``.

    Raises:
        gr.Error: If no input image was provided. Without this check,
            ``Image.fromarray(None)`` fails with an opaque error.
    """
    if image is None:
        raise gr.Error('Please provide an input image.')

    rgba_image = Image.fromarray(image).convert('RGBA')
    return generate_animation_images(rgba_image)
|
|
|
|
|
# Build the two-section UI: "Parallax" (image -> layers + indexes + NSFW flag)
# and "Animation" (image -> gallery of outputs). Component creation order and
# nesting inside the context managers define the on-page layout.
with gr.Blocks(theme=Theme()) as demo:
    gr.Markdown('**Parallax**')

    with gr.Row():
        # Parallax inputs: source image (requested as RGBA) and the number of layers.
        with gr.Column():
            parallax_input_image = gr.Image(image_mode='RGBA', label='Input')
            # precision=0 makes Gradio round the value to an integer; UI bounds are 2..10.
            parallax_input_layers = gr.Number(value=5, precision=0, minimum=2, maximum=10, step=1, label='Layers')
            parallax_generate_button = gr.Button('Generate')
        # Parallax outputs, in the same order as generate_parallax's returned list.
        with gr.Column():
            parallax_output_gallery = gr.Gallery(label='Outputs', columns=5)
            parallax_output_json = gr.JSON(label='Indexes')
            parallax_output_nsfw = gr.Checkbox(label='NSFW')

    gr.Markdown('**Animation**')

    with gr.Row():
        with gr.Column():
            animation_input_image = gr.Image(image_mode='RGBA', label='Input')
            animation_generate_button = gr.Button('Generate')
        with gr.Column():
            animation_output_gallery = gr.Gallery(label='Outputs', columns=5)

    # Event wiring must happen inside the Blocks context; api_name also exposes
    # each handler as a named API endpoint.
    parallax_generate_button.click(fn=generate_parallax, inputs=[parallax_input_image, parallax_input_layers], outputs=[parallax_output_gallery, parallax_output_json, parallax_output_nsfw], api_name='generate_parallax')
    animation_generate_button.click(fn=generate_animation, inputs=animation_input_image, outputs=animation_output_gallery, api_name='generate_animation')


# Start the Gradio server only when run as a script.
if __name__ == '__main__':
    demo.launch()
|
|