| |
| |
|
|
| import torch |
| from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline, StableCascadeCombinedPipeline |
|
|
# Hugging Face Hub repo id passed to the decoder and combined pipelines.
cas = "stabilityai/stable-cascade"
# Hugging Face Hub repo id passed to the prior pipeline.
cas_prior = "stabilityai/stable-cascade-prior"
|
|
def t2i_(prompt):
    """Generate one image for ``prompt`` using separate prior and decoder pipelines.

    The prior pipeline is loaded from ``cas_prior`` and the decoder pipeline
    from ``cas``; both are moved to ``"cuda"``. The prior output's
    ``image_embeddings`` are then handed to the decoder.

    Note that both pipelines are loaded again on every call.

    Args:
        prompt: Text prompt forwarded to both the prior and the decoder.

    Returns:
        The first entry of the decoder result's ``images`` (requested with
        ``output_type="pil"``).
    """
    # The prior is loaded in bfloat16 while the decoder is loaded in float16;
    # both request the "bf16" weight variant.
    prior = StableCascadePriorPipeline.from_pretrained(
        cas_prior, variant="bf16", torch_dtype=torch.bfloat16
    )
    decoder = StableCascadeDecoderPipeline.from_pretrained(
        cas, variant="bf16", torch_dtype=torch.float16
    )

    prior.to("cuda")
    decoder.to("cuda")

    prior_output = prior(
        prompt=prompt,
        height=1024,
        width=1024,
        negative_prompt="",
        guidance_scale=4.0,
        num_images_per_prompt=1,
        num_inference_steps=20,
    )

    # Cast the embeddings to float16 so they match the dtype the decoder was
    # loaded with.
    embeddings = prior_output.image_embeddings.to(torch.float16)

    decoder_output = decoder(
        image_embeddings=embeddings,
        prompt=prompt,
        negative_prompt="",
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
    )

    return decoder_output.images[0]
|
|
def t2i(prompt):
    """Generate one image for ``prompt`` using the combined Stable Cascade pipeline.

    The combined pipeline is loaded from ``cas`` (variant ``"bf16"``, dtype
    ``torch.bfloat16``) and moved to ``"cuda"``. The pipeline is loaded again
    on every call.

    Args:
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    pipe = StableCascadeCombinedPipeline.from_pretrained(
        cas, variant="bf16", torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")

    result = pipe(
        prompt=prompt,
        negative_prompt="",
        num_inference_steps=10,
        prior_num_inference_steps=20,
        prior_guidance_scale=3.0,
        width=1024,
        height=1024,
    )

    return result.images[0]
|
|
if __name__ == "__main__":
    # Script entry point: render a single sample prompt through the combined
    # pipeline and write the result to the current working directory.
    prompt = "a girl in beijing"
    image = t2i(prompt)

    image.save("stablecascade_output.png")