Abhi5ingh commited on
Commit
5b344d3
1 Parent(s): d942b6f

deployed-v1

Browse files
Files changed (3) hide show
  1. app.py +161 -0
  2. requirements.txt +7 -0
  3. sdfile.py +90 -0
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ import cv2
5
+ import streamlit as st
6
+ from PIL import Image
7
+
8
+ from sd.sdfile import PIPELINES, generate
9
+
10
# Default prompt shown in every prompt text area.
DEFAULT_PROMPT = "belted shirt black belted portrait-collar wrap blouse with black prints"
# Default canvas size for text-to-image generation.
DEFAULT_WIDTH, DEFAULT_HEIGHT = 512, 512
# Backward-compat alias: the original name was misspelled "DEAFULT_WIDTH";
# kept so any external reference to the old name still resolves.
DEAFULT_WIDTH = DEFAULT_WIDTH
# Session-state keys for the most recent generated / uploaded image.
OUTPUT_IMAGE_KEY = "output_img"
LOADED_IMAGE_KEY = "loaded_img"
14
+
15
def get_image(key: str) -> Optional[Image.Image]:
    """Return the image cached in Streamlit session state under *key*, or None."""
    return st.session_state[key] if key in st.session_state else None
19
+
20
def set_image(key: str, img: Image.Image) -> None:
    """Cache *img* in Streamlit session state under *key*."""
    st.session_state[key] = img
22
+
23
def prompt_and_generate_button(prefix, pipeline_name: PIPELINES, **kwargs):
    """Render prompt inputs and generation controls; run the pipeline on click.

    Widget keys are namespaced by *prefix* so the same controls can appear on
    several tabs. Extra keyword arguments are forwarded verbatim to
    ``generate`` and stay untouched here.
    """
    prompt = st.text_area("Prompt", value=DEFAULT_PROMPT, key=f"{prefix}-prompt")
    negative_prompt = st.text_area(
        "Negative prompt",
        value="",
        key=f"{prefix}-negative_prompt",
    )

    steps_col, guidance_col = st.columns(2)
    with steps_col:
        steps = st.slider(
            "Number of inference steps",
            min_value=1,
            max_value=200,
            value=30,
            key=f"{prefix}-inference-steps",
        )
    with guidance_col:
        guidance_scale = st.slider(
            "Guidance scale",
            min_value=0.0,
            max_value=20.0,
            value=7.5,
            step=0.5,
            key=f"{prefix}-guidance-scale",
        )
    enable_cpu_offload = st.checkbox(
        "Enable CPU offload if you run out of memory",
        key=f"{prefix}-cpu-offload",
        value=False,
    )

    # Guard clause: nothing more to do until the button is pressed.
    if not st.button("Generate Image", key=f"{prefix}-btn"):
        return

    with st.spinner("Generating image ..."):
        image = generate(
            prompt,
            pipeline_name,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            enable_cpu_offload=enable_cpu_offload,
            **kwargs,
        )
        # Copy before caching so later in-place edits can't mutate the cache.
        set_image(OUTPUT_IMAGE_KEY, image.copy())
        st.image(image)
71
def width_and_height_sliders(prefix):
    """Render side-by-side width/height sliders (64-1600 px, step 16).

    Returns the selected ``(width, height)`` pair.
    """
    width_col, height_col = st.columns(2)
    # Both sliders share the same numeric range and default.
    slider_opts = dict(min_value=64, max_value=1600, step=16, value=512)
    with width_col:
        width = st.slider("Width", key=f"{prefix}-width", **slider_opts)
    with height_col:
        height = st.slider("Height", key=f"{prefix}-height", **slider_opts)
    return width, height
92
+
93
def image_uploader(prefix):
    """Return a freshly uploaded image, falling back to the session-state cache.

    Returns None when nothing was uploaded and no cached image exists.
    """
    uploaded = st.file_uploader("Image", ["jpg", "png"], key=f"{prefix}-uploader")
    if not uploaded:
        return get_image(LOADED_IMAGE_KEY)
    image = Image.open(uploaded)
    print(f"loaded input image of size ({image.width}, {image.height})")
    return image
101
+
102
def sketching():
    """Turn the uploaded image into a binary edge sketch for ControlNet.

    Returns a PIL image containing the sketch, or ``None`` when no image has
    been uploaded yet.
    """
    image = image_uploader("Controlnet")
    if not image:
        # Bug fix: the original returned the tuple (None, None), which is
        # truthy and broke the caller's `if sketch_pil:` check. Return a
        # single None so truthiness works and arity matches the other path.
        return None

    # Bug fix: cv2.imread expects a file path; we already hold a PIL image,
    # so convert it to a NumPy array (PIL gives RGB, hence RGB2GRAY rather
    # than the original BGR2GRAY that applies to imread output).
    gray = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2GRAY)
    image_blur = cv2.GaussianBlur(gray, (5, 5), 0)
    # Bug fix: constant is THRESH_BINARY — the original "THRES_BINARY"
    # raised AttributeError at runtime.
    sketch = cv2.adaptiveThreshold(
        image_blur, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 2
    )
    return Image.fromarray(sketch)
113
+
114
def txt2img_tab():
    """Text-to-image tab: size sliders followed by the prompt/generate controls."""
    prefix = "txt2img"
    width, height = width_and_height_sliders(prefix)
    prompt_and_generate_button(prefix, "txt2img", width=width, height=height)
118
+
119
def sketching_tab():
    """Sketch-to-image tab: upload/derive a sketch, then generate from it."""
    prefix = "sketch2img"
    sketch_col, controls_col = st.columns(2)
    with sketch_col:
        sketch_pil = sketching()
    with controls_col:
        # Controls only appear once a sketch is available.
        if sketch_pil:
            controlnet_conditioning_scale = st.slider(
                "Strength or dependence on the input sketch",
                min_value=0.0,
                max_value=1.0,
                value=0.5,
                step=0.05,
                key=f"{prefix}-controlnet_conditioning_scale",
            )
            prompt_and_generate_button(
                prefix,
                "sketch2img",
                sketch_pil=sketch_pil,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
            )
140
+
141
def main():
    """Entry point: configure the page and lay out the two generation tabs."""
    st.set_page_config(layout="wide")
    st.title("Fashion-SDX: Playground")

    text_tab, sketch_tab = st.tabs(["Text to image", "Sketch to image"])
    with text_tab:
        txt2img_tab()
    with sketch_tab:
        sketching_tab()

    # Sidebar always shows the most recently generated image, if any.
    with st.sidebar:
        st.header("Most Recent Output Image")
        output_image = get_image(OUTPUT_IMAGE_KEY)
        if output_image:
            st.image(output_image)
        else:
            st.markdown("no output generated yet")
160
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
accelerate>=0.16.0
torchvision
transformers>=4.25.1
datasets
ftfy
tensorboard
Jinja2
streamlit
torch
diffusers
opencv-python
Pillow
sdfile.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import datetime
3
+ import os
4
+ import re
5
+ from typing import Literal
6
+
7
+ import streamlit as st
8
+ import torch
9
+ from diffusers import (
10
+ StableDiffusionPipeline,
11
+ StableDiffusionControlNetPipeline,
12
+ ControlNetModel,
13
+ EulerDiscreteScheduler,
14
+ DDIMScheduler,
15
+ )
16
+
17
+ PIPELINES = Literal["txt2img", "sketch2img"]
18
+
19
@st.cache_resource(max_entries=1)
def get_pipelines( name:PIPELINES, enable_cpu_offload = False, ) -> StableDiffusionPipeline:
    """Build and cache the diffusion pipeline for *name*.

    Parameters
    ----------
    name : "txt2img" or "sketch2img".
    enable_cpu_offload : offload model parts to CPU when idle (saves VRAM);
        otherwise the whole pipeline is moved to CUDA.

    Raises
    ------
    Exception if *name* is not a known pipeline.
    """
    # Bug fix: the original plain string contained "\v" which Python reads
    # as a vertical-tab escape, silently corrupting this Windows path. A raw
    # string keeps every backslash literal.
    # NOTE(review): hard-coded local path to LoRA attention weights — will
    # not exist on deployment hosts; confirm intended location.
    lora_weights_dir = r"D:\PycharmProjects\pythonProject\venv"

    pipe = None
    if name == "txt2img":
        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
        )
    elif name == "sketch2img":
        controlnet = ControlNetModel.from_pretrained(
            "Abhi5ingh/model_dresscode", torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            torch_dtype=torch.float16,
        )

    if pipe is None:
        raise Exception(f"Pipeline not Found {name}")

    # Shared setup that was duplicated verbatim in both branches above.
    pipe.unet.load_attn_procs(lora_weights_dir)
    # NOTE(review): this disables the NSFW safety checker for every image.
    pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

    if enable_cpu_offload:
        print("Enabling cpu offloading for the given pipeline")
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to("cuda")
    return pipe
42
+
43
def generate(
    prompt,
    pipeline_name: PIPELINES,
    sketch_pil = None,
    num_inference_steps = 30,
    negative_prompt = None,
    width = 512,
    height = 512,
    guidance_scale = 7.5,
    controlnet_conditioning_scale = None,
    enable_cpu_offload= False):
    """Run the selected pipeline and return the first generated PIL image.

    Also saves the image plus its prompts under ``outputs/`` and drives a
    Streamlit progress bar via the pipeline step callback.

    Raises
    ------
    Exception if *pipeline_name* is not one of the supported pipelines.
    """
    # Empty string means "no negative prompt" for the pipeline.
    negative_prompt = negative_prompt if negative_prompt else None
    p = st.progress(0)
    callback = lambda step, *_: p.progress(step / num_inference_steps)
    pipe = get_pipelines(pipeline_name, enable_cpu_offload=enable_cpu_offload)
    torch.cuda.empty_cache()

    kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        callback=callback,
        guidance_scale=guidance_scale,
    )
    print("kwargs", kwargs)

    if pipeline_name == "sketch2img" and sketch_pil:
        # Bug fix: the ControlNet pipeline takes its conditioning image via
        # the `image=` keyword; an unexpected `sketch_pil=` kwarg would raise
        # a TypeError when the pipeline is called.
        kwargs.update(
            image=sketch_pil,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
        )
    elif pipeline_name == "txt2img":
        kwargs.update(width = width, height = height)
    else:
        raise Exception(
            f"Cannot generate image for pipeline {pipeline_name} and {prompt}")

    # Bug fix: the original never invoked the pipeline, so `images` was an
    # undefined name (NameError).
    images = pipe(**kwargs).images
    image = images[0]

    os.makedirs("outputs", exist_ok=True)
    # Bug fix: `filename` was computed but never used — saves went to a
    # literal placeholder path. Use it for both the image and its prompt log.
    filename = (
        "outputs/"
        + re.sub(r"\s+", "_", prompt)[:30]
        + f"_{datetime.datetime.now().timestamp()}"
    )
    image.save(f"{filename}.png")
    # Bug fix: the original had mismatched parentheses here (a stray `)`
    # after `return`), which was a syntax error.
    with open(f"{filename}.txt", "w") as f:
        f.write(f"Prompt: {prompt}\n\nNegative Prompt:{negative_prompt}")
    return image