ameerazam08 committed on
Commit
3081d59
•
1 Parent(s): 47061cf

Update app.py

Files changed (1)
  1. app.py +257 -0
app.py CHANGED
@@ -0,0 +1,257 @@
+ # pip install diffusers transformers accelerate safetensors huggingface_hub
+
+
+ import os
+ os.system("pip install -U peft")
+ import random
+
+ import gradio as gr
+ import numpy as np
+ import PIL.Image
+ # import spaces
+ import torch
+ from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler
+ from huggingface_hub import hf_hub_download
+
+ DESCRIPTION = """
+ # Res-Adapter
+ **Demo by ameer azam - [Twitter](https://twitter.com/Ameerazam18) - [GitHub](https://github.com/AMEERAZAM08) - [Hugging Face](https://huggingface.co/ameerazam08)**
+ This is a demo of the [res-adapter](https://huggingface.co/jiaxiangc/res-adapter) LoRAs by ByteDance.
+ """
+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<h1>Running on CPU 🥶 This demo does not work on CPU.</h1>"
+
+ MAX_SEED = np.iinfo(np.int32).max
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
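+ # Note: USE_TORCH_COMPILE and ENABLE_CPU_OFFLOAD are read here but are not used further in this script.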
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ if torch.cuda.is_available():
+     pipe = AutoPipelineForText2Image.from_pretrained('Lykon/dreamshaper-xl-1-0', torch_dtype=torch.float16, variant="fp16")
+     pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")
+     pipe = pipe.to("cuda")
+
+
+     pipe.load_lora_weights(
+         hf_hub_download(
+             repo_id="jiaxiangc/res-adapter",
+             subfolder="sdxl-i",
+             filename="resolution_lora.safetensors",
+         ),
+         adapter_name="res_adapter",
+     )
+     pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])
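+     # adapter_weights=[1.0] applies the resolution LoRA at full strength; values
+     # between 0 and 1 would scale its effect relative to the base model weights.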
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+
+ # @spaces.GPU(enable_queue=True)
+ def generate(
+     prompt: str,
+     negative_prompt: str = "",
+     prompt_2: str = "",
+     negative_prompt_2: str = "",
+     use_negative_prompt: bool = False,
+     use_prompt_2: bool = False,
+     use_negative_prompt_2: bool = False,
+     seed: int = 0,
+     width: int = 1024,
+     height: int = 1024,
+     guidance_scale_base: float = 5.0,
+     num_inference_steps_base: int = 20,
+     progress=gr.Progress(track_tqdm=True),
+ ) -> list[PIL.Image.Image]:
+     print(f"** Generating image for: \"{prompt}\" **")
+     generator = torch.Generator().manual_seed(seed)
+
+     if not use_negative_prompt:
+         negative_prompt = None  # type: ignore
+     if not use_prompt_2:
+         prompt_2 = None  # type: ignore
+     if not use_negative_prompt_2:
+         negative_prompt_2 = None  # type: ignore
+
+     # First pass: base model only, with the res-adapter LoRA disabled so the
+     # two gallery images actually differ.
+     pipe.disable_lora()
+     base_image = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         prompt_2=prompt_2,
+         negative_prompt_2=negative_prompt_2,
+         width=width,
+         height=height,
+         guidance_scale=guidance_scale_base,
+         num_inference_steps=num_inference_steps_base,
+         generator=generator,
+         output_type="pil").images[0]
+
+     # Second pass: same settings and seed, but with the res-adapter LoRA enabled.
+     pipe.enable_lora()
+     generator = torch.Generator().manual_seed(seed)
+     res_adapt = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         prompt_2=prompt_2,
+         negative_prompt_2=negative_prompt_2,
+         width=width,
+         height=height,
+         guidance_scale=guidance_scale_base,
+         num_inference_steps=num_inference_steps_base,
+         generator=generator,
+         output_type="pil",
+     ).images[0]
+     return [base_image, res_adapt]
+
+
+ examples = [
+     "A realistic photograph of an astronaut in a jungle, cold color palette, detailed, 8k",
+     "An astronaut riding a green horse",
+ ]
+
+ theme = gr.themes.Base(
+     font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
+ )
+ with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(
+         value="Duplicate Space for private use",
+         elem_id="duplicate-button",
+         visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+     )
+     with gr.Group():
+         prompt = gr.Text(
+             label="Prompt",
+             show_label=False,
+             max_lines=1,
+             container=False,
+             placeholder="Enter your prompt",
+         )
+         run_button = gr.Button("Generate")
+     result = gr.Gallery(label="Left is Base and Right is LoRA")
+     with gr.Accordion("Advanced options", open=False):
+         with gr.Row():
+             use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
+             use_prompt_2 = gr.Checkbox(label="Use prompt 2", value=False)
+             use_negative_prompt_2 = gr.Checkbox(label="Use negative prompt 2", value=False)
+         negative_prompt = gr.Text(
+             label="Negative prompt",
+             max_lines=1,
+             placeholder="Enter a negative prompt",
+             visible=False,
+         )
+         prompt_2 = gr.Text(
+             label="Prompt 2",
+             max_lines=1,
+             placeholder="Enter your prompt",
+             visible=False,
+         )
+         negative_prompt_2 = gr.Text(
+             label="Negative prompt 2",
+             max_lines=1,
+             placeholder="Enter a negative prompt",
+             visible=False,
+         )
+
+         seed = gr.Slider(
+             label="Seed",
+             minimum=0,
+             maximum=MAX_SEED,
+             step=1,
+             value=0,
+         )
+         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+         with gr.Row():
+             width = gr.Slider(
+                 label="Width",
+                 minimum=256,
+                 maximum=MAX_IMAGE_SIZE,
+                 step=32,
+                 value=1024,
+             )
+             height = gr.Slider(
+                 label="Height",
+                 minimum=256,
+                 maximum=MAX_IMAGE_SIZE,
+                 step=32,
+                 value=1024,
+             )
+         with gr.Row():
+             guidance_scale_base = gr.Slider(
+                 label="Guidance scale for base",
+                 minimum=1,
+                 maximum=20,
+                 step=0.1,
+                 value=7.5,
+             )
+             num_inference_steps_base = gr.Slider(
+                 label="Number of inference steps for base",
+                 minimum=10,
+                 maximum=100,
+                 step=1,
+                 value=20,
+             )
+     gr.Examples(
+         examples=examples,
+         inputs=prompt,
+         outputs=result,
+         fn=generate,
+         cache_examples=CACHE_EXAMPLES,
+     )
+
+     use_negative_prompt.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=use_negative_prompt,
+         outputs=negative_prompt,
+         queue=False,
+         api_name=False,
+     )
+     use_prompt_2.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=use_prompt_2,
+         outputs=prompt_2,
+         queue=False,
+         api_name=False,
+     )
+     use_negative_prompt_2.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=use_negative_prompt_2,
+         outputs=negative_prompt_2,
+         queue=False,
+         api_name=False,
+     )
+     gr.on(
+         triggers=[
+             prompt.submit,
+             negative_prompt.submit,
+             prompt_2.submit,
+             negative_prompt_2.submit,
+             run_button.click,
+         ],
+         fn=randomize_seed_fn,
+         inputs=[seed, randomize_seed],
+         outputs=seed,
+         queue=False,
+         api_name=False,
+     ).then(
+         fn=generate,
+         inputs=[
+             prompt,
+             negative_prompt,
+             prompt_2,
+             negative_prompt_2,
+             use_negative_prompt,
+             use_prompt_2,
+             use_negative_prompt_2,
+             seed,
+             width,
+             height,
+             guidance_scale_base,
+             num_inference_steps_base,
+         ],
+         outputs=result,
+         api_name="run",
+     )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20, api_open=False).launch(show_api=False)