Pranav Pandey commited on
Commit
22a3fc5
1 Parent(s): e1a92c5

Add application file

Browse files
Files changed (4) hide show
  1. app.py +96 -2
  2. pl1.jpeg +0 -0
  3. pl2.png +0 -0
  4. requirements.txt +5 -0
app.py CHANGED
@@ -1,5 +1,99 @@
1
  import streamlit as st
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import time
3
+ import torch
4
+ from torch import autocast
5
+ from diffusers import StableDiffusionPipeline
6
+ from datasets import load_dataset
7
+ from PIL import Image
8
+ import re
9
 
10
+ st.title("Text-to-Image generation using Stable Diffusion")
11
+ st.subheader("Text Prompt")
12
+ text_prompt = st.text_area('Enter here:', height=100)
13
 
14
+ sl1, sl2, sl3, sl4 = st.columns(4)
15
+
16
+ num_samples = sl1.slider('Number of Images', 1, 4, 1)
17
+ num_steps = sl2.slider('Diffusion steps', 10, 150, 10)
18
+ scale = sl3.slider('Configuration scale', 0, 20, 10)
19
+ seed = sl4.number_input("Enter seed", 0, 25000, 47, 1)
20
+
21
+
22
+ model_id = "CompVis/stable-diffusion-v1-4"
23
+ device = "cuda"
24
+
25
+ pipe = StableDiffusionPipeline.from_pretrained(
26
+ model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
27
+ pipe = pipe.to(device)
28
+ word_list_dataset = load_dataset(
29
+ "stabilityai/word-list", data_files="list.txt", use_auth_token=True)
30
+ word_list = word_list_dataset["train"]['text']
31
+
32
+
33
+ def infer(prompt, samples, steps, scale, seed):
34
+ for filter in word_list:
35
+ if re.search(rf"\b{filter}\b", prompt):
36
+ raise Exception(
37
+ "Unsafe content found. Please try again with different prompts.")
38
+
39
+ generator = torch.Generator(device=device).manual_seed(seed)
40
+ with autocast("cuda"):
41
+ images_list = pipe(
42
+ [prompt] * samples,
43
+ num_inference_steps=steps,
44
+ guidance_scale=scale,
45
+ generator=generator,
46
+ )
47
+ images = []
48
+ safe_image = Image.open(r"unsafe.png")
49
+ for i, image in enumerate(images_list["sample"]):
50
+ if (images_list["nsfw_content_detected"][i]):
51
+ images.append(safe_image)
52
+ else:
53
+ images.append(image)
54
+ return images
55
+
56
+
57
+ def check_and_infer():
58
+
59
+ if len(text_prompt) < 5:
60
+ st.write("Prompt too small, enter some more detail")
61
+ st.experimental_rerun()
62
+ else:
63
+ with st.spinner('Wait for it...'):
64
+ generated_images = infer(
65
+ text_prompt, num_samples, num_steps, scale, seed)
66
+ for image in generated_images:
67
+ st.image(image, caption=text_prompt)
68
+ st.success('Image generated!')
69
+ st.balloons()
70
+
71
+
72
+ button_clicked = st.button(
73
+ "Generate Image", on_click=check_and_infer, disabled=False)
74
+
75
+ st.markdown("""---""")
76
+
77
+ col1, col2, col3 = st.columns([1, 6, 1])
78
+
79
+ with col1:
80
+ col1.write("")
81
+
82
+ with col2:
83
+ placeholder = col2.empty()
84
+
85
+ placeholder.image("pl2.png")
86
+
87
+ with col3:
88
+ col1.write("")
89
+
90
+
91
+ for image in []:
92
+ st.image(image, caption=text_prompt)
93
+
94
+ st.markdown("""---""")
95
+
96
+ st.text("Number of Images: Number of samples(Images) to generate")
97
+ st.text("Diffusion steps: How many steps to spend generating (diffusing) your image.")
98
+ st.text("Configuration scale: Scale adjusts how close the image will be to your prompt. Higher values keep your image closer to your prompt.")
99
+ st.text("Enter seed: Seed value to use for the model.")
pl1.jpeg ADDED
pl2.png ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ diffusers
2
+ transformers
3
+ nvidia-ml-py3
4
+ --extra-index-url https://download.pytorch.org/whl/cu113 torch
5
+