bhasha.dev / app.py
Dhruv Diddi
launch share updates space
f7072bd
raw
history blame
No virus
2.75 kB
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from datasets import load_dataset
from PIL import Image
import re
import os
auth_token = os.getenv("auth_token")
model_id = "CompVis/stable-diffusion-v1-4"
device = "cpu"
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=auth_token, revision="fp16", torch_dtype=torch.float16)
pipe = pipe.to(device)
def infer(prompt, samples, steps, scale, seed):
generator = torch.Generator(device=device).manual_seed(seed)
images_list = pipe(
[prompt] * samples,
num_inference_steps=steps,
guidance_scale=scale,
generator=generator,
)
images = []
safe_image = Image.open(r"unsafe.png")
for i, image in enumerate(images_list["sample"]):
if(images_list["nsfw_content_detected"][i]):
images.append(safe_image)
else:
images.append(image)
return images
block = gr.Blocks()
with block:
with gr.Group():
with gr.Box():
with gr.Row().style(mobile_collapse=False, equal_height=True):
text = gr.Textbox(
label="Enter your prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
).style(
border=(True, False, True, True),
rounded=(True, False, False, True),
container=False,
)
btn = gr.Button("Generate image").style(
margin=False,
rounded=(False, True, True, False),
)
gallery = gr.Gallery(
label="Generated images", show_label=False, elem_id="gallery"
).style(grid=[2], height="auto")
advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
with gr.Row(elem_id="advanced-options"):
samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
scale = gr.Slider(
label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=2147483647,
step=1,
randomize=True,
)
text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
advanced_button.click(
None,
[],
text,
)
block.launch()