Joabutt commited on
Commit
0606ac4
1 Parent(s): aa64ef6

Added app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch import autocast
4
+ from diffusers import StableDiffusionPipeline
5
+ from datasets import load_dataset
6
+ from PIL import Image
7
+ import re
8
+ import os
9
+
10
+ auth_token = os.getenv("auth_token")
11
+ model_id = "CompVis/stable-diffusion-v1-4"
12
+ device = "cpu"
13
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=auth_token, revision="fp16", torch_dtype=torch.float16)
14
+ pipe = pipe.to(device)
15
+
16
+ def infer(prompt, samples, steps, scale, seed):
17
+ generator = torch.Generator(device=device).manual_seed(seed)
18
+ images_list = pipe(
19
+ [prompt] * samples,
20
+ num_inference_steps=steps,
21
+ guidance_scale=scale,
22
+ generator=generator,
23
+ )
24
+ images = []
25
+ safe_image = Image.open(r"unsafe.png")
26
+ for i, image in enumerate(images_list["sample"]):
27
+ if(images_list["nsfw_content_detected"][i]):
28
+ images.append(safe_image)
29
+ else:
30
+ images.append(image)
31
+ return images
32
+
33
+
34
+
35
+ block = gr.Blocks()
36
+
37
+ with block:
38
+ with gr.Group():
39
+ with gr.Box():
40
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
41
+ text = gr.Textbox(
42
+ label="Enter your prompt",
43
+ show_label=False,
44
+ max_lines=1,
45
+ placeholder="Enter your prompt",
46
+ ).style(
47
+ border=(True, False, True, True),
48
+ rounded=(True, False, False, True),
49
+ container=False,
50
+ )
51
+ btn = gr.Button("Generate image").style(
52
+ margin=False,
53
+ rounded=(False, True, True, False),
54
+ )
55
+ gallery = gr.Gallery(
56
+ label="Generated images", show_label=False, elem_id="gallery"
57
+ ).style(grid=[2], height="auto")
58
+
59
+ advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
60
+
61
+ with gr.Row(elem_id="advanced-options"):
62
+ samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
63
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
64
+ scale = gr.Slider(
65
+ label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
66
+ )
67
+ seed = gr.Slider(
68
+ label="Seed",
69
+ minimum=0,
70
+ maximum=2147483647,
71
+ step=1,
72
+ randomize=True,
73
+ )
74
+ text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
75
+ btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
76
+ advanced_button.click(
77
+ None,
78
+ [],
79
+ text,
80
+ )
81
+
82
+ block.launch()