Sergidev commited on
Commit
18e3ad7
·
verified ·
1 Parent(s): bc0df63
Files changed (1) hide show
  1. app.py +199 -2
app.py CHANGED
@@ -14,9 +14,206 @@ from datetime import datetime
14
  from diffusers.models import AutoencoderKL
15
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
16
 
17
- # ... (keep all the imports and initial setup)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # ... (keep all the functions like load_pipeline, parse_json_parameters, apply_json_parameters, generate, get_random_prompt)
20
 
21
  if torch.cuda.is_available():
22
  pipe = load_pipeline(MODEL)
 
14
  from diffusers.models import AutoencoderKL
15
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
16
 
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ DESCRIPTION = "PonyDiffusion V6 XL"
21
+ if not torch.cuda.is_available():
22
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
23
+ IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
24
+ HF_TOKEN = os.getenv("HF_TOKEN")
25
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
26
+ MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
27
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
28
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
29
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
30
+ OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
31
+
32
+ MODEL = os.getenv(
33
+ "MODEL",
34
+ "https://huggingface.co/AstraliteHeart/pony-diffusion-v6/blob/main/v6.safetensors",
35
+ )
36
+
37
+ torch.backends.cudnn.deterministic = True
38
+ torch.backends.cudnn.benchmark = False
39
+
40
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
41
+
42
+ def load_pipeline(model_name):
43
+ vae = AutoencoderKL.from_pretrained(
44
+ "madebyollin/sdxl-vae-fp16-fix",
45
+ torch_dtype=torch.float16,
46
+ )
47
+ pipeline = (
48
+ StableDiffusionXLPipeline.from_single_file
49
+ if MODEL.endswith(".safetensors")
50
+ else StableDiffusionXLPipeline.from_pretrained
51
+ )
52
+
53
+ pipe = pipeline(
54
+ model_name,
55
+ vae=vae,
56
+ torch_dtype=torch.float16,
57
+ custom_pipeline="lpw_stable_diffusion_xl",
58
+ use_safetensors=True,
59
+ add_watermarker=False,
60
+ use_auth_token=HF_TOKEN,
61
+ variant="fp16",
62
+ )
63
+
64
+ pipe.to(device)
65
+ return pipe
66
+
67
+ def parse_json_parameters(json_str):
68
+ try:
69
+ params = json.loads(json_str)
70
+ return params
71
+ except json.JSONDecodeError:
72
+ return None
73
+
74
+ def apply_json_parameters(json_str):
75
+ params = parse_json_parameters(json_str)
76
+ if params:
77
+ return (
78
+ params.get("prompt", ""),
79
+ params.get("negative_prompt", ""),
80
+ params.get("seed", 0),
81
+ params.get("width", 1024),
82
+ params.get("height", 1024),
83
+ params.get("guidance_scale", 7.0),
84
+ params.get("num_inference_steps", 30),
85
+ params.get("sampler", "DPM++ 2M SDE Karras"),
86
+ params.get("aspect_ratio", "1024 x 1024"),
87
+ params.get("use_upscaler", False),
88
+ params.get("upscaler_strength", 0.55),
89
+ params.get("upscale_by", 1.5),
90
+ )
91
+ return [gr.update()] * 12
92
+
93
+ def generate(
94
+ prompt: str,
95
+ negative_prompt: str = "",
96
+ seed: int = 0,
97
+ custom_width: int = 1024,
98
+ custom_height: int = 1024,
99
+ guidance_scale: float = 7.0,
100
+ num_inference_steps: int = 30,
101
+ sampler: str = "DPM++ 2M SDE Karras",
102
+ aspect_ratio_selector: str = "1024 x 1024",
103
+ use_upscaler: bool = False,
104
+ upscaler_strength: float = 0.55,
105
+ upscale_by: float = 1.5,
106
+ progress=gr.Progress(track_tqdm=True),
107
+ ) -> Image:
108
+ generator = utils.seed_everything(seed)
109
+
110
+ width, height = utils.aspect_ratio_handler(
111
+ aspect_ratio_selector,
112
+ custom_width,
113
+ custom_height,
114
+ )
115
+
116
+ width, height = utils.preprocess_image_dimensions(width, height)
117
+
118
+ backup_scheduler = pipe.scheduler
119
+ pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
120
+
121
+ if use_upscaler:
122
+ upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
123
+ metadata = {
124
+ "prompt": prompt,
125
+ "negative_prompt": negative_prompt,
126
+ "resolution": f"{width} x {height}",
127
+ "guidance_scale": guidance_scale,
128
+ "num_inference_steps": num_inference_steps,
129
+ "seed": seed,
130
+ "sampler": sampler,
131
+ }
132
+
133
+ if use_upscaler:
134
+ new_width = int(width * upscale_by)
135
+ new_height = int(height * upscale_by)
136
+ metadata["use_upscaler"] = {
137
+ "upscale_method": "nearest-exact",
138
+ "upscaler_strength": upscaler_strength,
139
+ "upscale_by": upscale_by,
140
+ "new_resolution": f"{new_width} x {new_height}",
141
+ }
142
+ else:
143
+ metadata["use_upscaler"] = None
144
+ logger.info(json.dumps(metadata, indent=4))
145
+
146
+ try:
147
+ if use_upscaler:
148
+ latents = pipe(
149
+ prompt=prompt,
150
+ negative_prompt=negative_prompt,
151
+ width=width,
152
+ height=height,
153
+ guidance_scale=guidance_scale,
154
+ num_inference_steps=num_inference_steps,
155
+ generator=generator,
156
+ output_type="latent",
157
+ ).images
158
+ upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
159
+ images = upscaler_pipe(
160
+ prompt=prompt,
161
+ negative_prompt=negative_prompt,
162
+ image=upscaled_latents,
163
+ guidance_scale=guidance_scale,
164
+ num_inference_steps=num_inference_steps,
165
+ strength=upscaler_strength,
166
+ generator=generator,
167
+ output_type="pil",
168
+ ).images
169
+ else:
170
+ images = pipe(
171
+ prompt=prompt,
172
+ negative_prompt=negative_prompt,
173
+ width=width,
174
+ height=height,
175
+ guidance_scale=guidance_scale,
176
+ num_inference_steps=num_inference_steps,
177
+ generator=generator,
178
+ output_type="pil",
179
+ ).images
180
+
181
+ if images and IS_COLAB:
182
+ for image in images:
183
+ filepath = utils.save_image(image, metadata, OUTPUT_DIR)
184
+ logger.info(f"Image saved as {filepath} with metadata")
185
+
186
+ # Update history after generation
187
+ history = gr.get_state("history") or []
188
+ history.insert(0, {"prompt": prompt, "image": images[0], "metadata": metadata})
189
+ gr.set_state("history", history[:10]) # Keep only the last 10 entries
190
+
191
+ return images, metadata, gr.update(choices=[h["prompt"] for h in history])
192
+ except Exception as e:
193
+ logger.exception(f"An error occurred: {e}")
194
+ raise
195
+ finally:
196
+ if use_upscaler:
197
+ del upscaler_pipe
198
+ pipe.scheduler = backup_scheduler
199
+ utils.free_memory()
200
+
201
+ def get_random_prompt():
202
+ anime_characters = [
203
+ "Naruto Uzumaki", "Monkey D. Luffy", "Goku", "Eren Yeager", "Light Yagami",
204
+ "Lelouch Lamperouge", "Edward Elric", "Levi Ackerman", "Spike Spiegel",
205
+ "Sakura Haruno", "Mikasa Ackerman", "Asuka Langley Soryu", "Rem", "Megumin",
206
+ "Violet Evergarden"
207
+ ]
208
+ styles = ["pixel art", "stylized anime", "digital art", "watercolor", "sketch"]
209
+ scores = ["score_9", "score_8_up", "score_7_up"]
210
+
211
+ character = random.choice(anime_characters)
212
+ style = random.choice(styles)
213
+ score = ", ".join(random.sample(scores, k=3))
214
+
215
+ return f"{score}, {character}, {style}, show accurate"
216
 
 
217
 
218
  if torch.cuda.is_available():
219
  pipe = load_pipeline(MODEL)