# Stable-Cascade / app.py
# nerualdreming's picture
# Upload 10 files
# 38d1711 verified
import argparse
import os
import random
from typing import Iterator, List

import gradio as gr
import numpy as np
import PIL.Image
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
from diffusers.utils import numpy_to_pil

from previewer.modules import Previewer
# Exported to the process environment before any model code runs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Markdown/HTML banner rendered at the top of the page.
DESCRIPTION = (
    "# Stable Cascade"
    "\n<p style=\"text-align: center\">Неофициальная демонстрация Stable Cascade от <a href='https://www.youtube.com/@nerual_dreming/' target='_blank'>Nerual Dreming и нейросети</a> основано на <a href='https://huggingface.co/stabilityai/stable-cascade' target='_blank'>Stable Cascade</a>, новая модель высокого разрешения для генерации изображений по текстовому запросу от Stability AI, основанная на архитектуре Würstchen - <a href='https://huggingface.co/stabilityai/stable-cascade/blob/main/LICENSE' target='_blank'>только для некоммерческого и научного использования</a></p>"
)

# Upper bound for the seed slider and for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max

# Feature switches; the two os.environ lookups can be overridden per deployment.
CACHE_EXAMPLES = False
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "1536"))
USE_TORCH_COMPILE = False
ENABLE_CPU_OFFLOAD = os.environ.get("ENABLE_CPU_OFFLOAD") == "1"
PREVIEW_IMAGES = True
# Command-line switches that control how the Gradio server is launched.
parser = argparse.ArgumentParser(description='Gradio App Control')
parser.add_argument('--share', action='store_true', help='Create a public shareable URL')
parser.add_argument('--inbrowser', action='store_true', help='Automatically launch the application in a browser')
parser.add_argument('--server_port', type=int, default=7860, help='Server port')

# This runs at import time, so sys.argv may belong to whatever launched or
# imported the module (a reloader, a test runner, ...). parse_known_args()
# keeps the three options above working while preventing foreign arguments
# from aborting the import with SystemExit.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print(f"ignoring unrecognized arguments: {_unknown_args}")
# Choose the compute device and the matching tensor dtype in one place:
# CUDA and CPU keep bfloat16, MPS switches to float32.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16
elif torch.backends.mps.is_available():
    device, dtype = "mps", torch.float32
else:
    device, dtype = "cpu", torch.bfloat16

print(f"device={device}")
# Model state shared with generate(). Everything defaults to None so that each
# name exists on every code path (previously the CPU branch left `previewer`,
# `callback_prior` and `callback_steps` undefined, which surfaced as a
# NameError on the first request).
prior_pipeline = None
decoder_pipeline = None
previewer = None
callback_prior = None
callback_steps = None

# The pipelines are only loaded when an accelerator (CUDA or MPS) was selected.
if device != "cpu":
    prior_pipeline = StableCascadePriorPipeline.from_pretrained(
        "stabilityai/stable-cascade-prior", torch_dtype=dtype
    )
    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
        "stabilityai/stable-cascade", torch_dtype=dtype
    )

    if ENABLE_CPU_OFFLOAD:
        prior_pipeline.enable_model_cpu_offload()
        decoder_pipeline.enable_model_cpu_offload()
    else:
        prior_pipeline.to(device)
        decoder_pipeline.to(device)

    if USE_TORCH_COMPILE:
        prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
        decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)

    if PREVIEW_IMAGES:
        previewer = Previewer()
        # NOTE(review): torch.load unpickles the file; only ship a trusted
        # checkpoint at this path.
        previewer_state_dict = torch.load(
            "previewer/previewer_v1_100k.pt", map_location=torch.device('cpu')
        )["state_dict"]
        previewer.load_state_dict(previewer_state_dict)

        def callback_prior(i, t, latents):
            """Turn the current prior latents into preview images.

            `i` and `t` are accepted for the pipeline's callback signature but
            are not used. The previewer output is clamped to [0, 1], moved to
            channels-last float NumPy on the CPU and converted with
            numpy_to_pil.
            """
            output = previewer(latents)
            output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
            return output

        # Invoke the preview callback on every prior step.
        callback_steps = 1
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in [0, MAX_SEED] if requested, else *seed* unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 30,
    # prior_timesteps: List[float] = None,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 12,
    # decoder_timesteps: List[float] = None,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 2,
    # profile: gr.OAuthProfile | None = None,
) -> Iterator[PIL.Image.Image]:
    """Run the prior and decoder pipelines for *prompt*, streaming images.

    This is a generator: while the prior runs it yields the first preview
    image of each callback result (when PREVIEW_IMAGES is on), and finally
    it yields the first decoded image.

    Args:
        prompt: Text prompt, passed to both pipelines.
        negative_prompt: Negative prompt, passed to both pipelines.
        seed: Seed for the torch.Generator shared by both pipelines.
        width: Output width forwarded to the prior pipeline.
        height: Output height forwarded to the prior pipeline.
        prior_num_inference_steps: Forwarded to the prior pipeline (which is
            also given DEFAULT_STAGE_C_TIMESTEPS as explicit timesteps).
        prior_guidance_scale: Guidance scale for the prior pipeline.
        decoder_num_inference_steps: Step count for the decoder pipeline.
        decoder_guidance_scale: Guidance scale for the decoder pipeline.
        num_images_per_prompt: Forwarded to the prior pipeline; only the
            first image of the result is yielded.

    Raises:
        gr.Error: If the pipelines were not loaded (no CUDA/MPS device).
    """
    # The pipelines stay None when the app started on a CPU-only machine;
    # report that to the UI instead of failing on an attribute of None.
    if prior_pipeline is None or decoder_pipeline is None:
        raise gr.Error("Stable Cascade pipelines are not loaded: a CUDA or MPS device is required.")

    # The previewer only exists when previews are enabled.
    if previewer is not None:
        previewer.eval().requires_grad_(False).to(device).to(dtype)

    # With model CPU offload enabled the pipelines manage device placement
    # themselves; moving them manually would undo the memory savings.
    if not ENABLE_CPU_OFFLOAD:
        prior_pipeline.to(device)
        decoder_pipeline.to(device)

    # manual_seed needs a Python int; UI widgets may hand over a float.
    generator = torch.Generator().manual_seed(int(seed))

    prior_output = prior_pipeline(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=prior_num_inference_steps,
        timesteps=DEFAULT_STAGE_C_TIMESTEPS,
        negative_prompt=negative_prompt,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        callback=callback_prior,
        callback_steps=callback_steps,
    )

    if PREVIEW_IMAGES:
        # NOTE(review): the prior call result is consumed as an iterator here,
        # one item per entry of DEFAULT_STAGE_C_TIMESTEPS - presumably a
        # diffusers build whose pipeline yields callback results; confirm.
        for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
            r = next(prior_output)
            if isinstance(r, list):
                yield r[0]
        # The last item pulled is used as the prior pipeline's output.
        prior_output = r

    decoder_output = decoder_pipeline(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        num_inference_steps=decoder_num_inference_steps,
        # timesteps=decoder_timesteps,
        guidance_scale=decoder_guidance_scale,
        negative_prompt=negative_prompt,
        generator=generator,
        output_type="pil",
    ).images
    yield decoder_output[0]
# Sample prompts offered below the prompt box (text kept exactly as shipped).
examples = [
    "An astronaut riding a green horse",
    "A mecha robot in a favela by Tarsila do Amaral",
    "The sprirt of a Tamagotchi wandering in the city of Los Angeles",
    "A delicious feijoada ramen dish",
]
# Main UI: prompt row, result image, advanced options, examples and the
# event wiring that connects them to generate().
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )

    # Prompt input, run button and the output image.
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Введите запрос",
                container=False,
            )
            run_button = gr.Button("Создать", scale=0)
        result = gr.Image(label="Result", show_label=False)

    # Collapsed panel with the remaining generation parameters.
    with gr.Accordion("Дополнительные опции", open=False):
        # NOTE(review): the label below contains a typo ("запорос" instead of
        # "запрос"); the text is left untouched here.
        negative_prompt = gr.Text(
            label="Негативный запорос",
            max_lines=1,
            placeholder="Введите негативный запрос",
        )
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
        randomize_seed = gr.Checkbox(label="Случайный seed", value=True)

        with gr.Row():
            # Width and height share the same range settings.
            _size_range = dict(minimum=1024, maximum=MAX_IMAGE_SIZE, step=512, value=1024)
            width = gr.Slider(label="Ширина", **_size_range)
            height = gr.Slider(label="Высота", **_size_range)
            num_images_per_prompt = gr.Slider(
                label="Количество изображений", minimum=1, maximum=2, step=1, value=1
            )

        with gr.Row():
            prior_guidance_scale = gr.Slider(
                label="Prior Guidance Scale", minimum=0, maximum=20, step=0.1, value=4.0
            )
            prior_num_inference_steps = gr.Slider(
                label="Prior Inference Steps", minimum=10, maximum=30, step=1, value=20
            )
            # minimum == maximum == 0, so this slider has a single possible value.
            decoder_guidance_scale = gr.Slider(
                label="Decoder Guidance Scale", minimum=0, maximum=0, step=0.1, value=0.0
            )
            decoder_num_inference_steps = gr.Slider(
                label="Decoder Inference Steps", minimum=4, maximum=12, step=1, value=10
            )

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    # Component order must match the positional parameters of generate().
    inputs = [
        prompt,
        negative_prompt,
        seed,
        width,
        height,
        prior_num_inference_steps,
        # prior_timesteps,
        prior_guidance_scale,
        decoder_num_inference_steps,
        # decoder_timesteps,
        decoder_guidance_scale,
        num_images_per_prompt,
    ]

    # Step 1: on submit/click, refresh the seed slider (un-queued, no API endpoint).
    seed_event = gr.on(
        triggers=[prompt.submit, negative_prompt.submit, run_button.click],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    )
    # Step 2: afterwards run the generation, exposed under the API name "run".
    seed_event.then(
        fn=generate,
        inputs=inputs,
        outputs=result,
        api_name="run",
    )
# Outer app shell: applies style.css and renders the main demo inside an "App" tab.
with gr.Blocks(css="style.css") as demo_with_history, gr.Tab("App"):
    demo.render()
if __name__ == "__main__":
    # Forward the parsed command-line switches straight to Gradio's launcher.
    demo_with_history.launch(
        inbrowser=args.inbrowser,
        share=args.share,
        server_port=args.server_port,
    )