Run it on low VRAM GPUs

#1
by sean-mediabox - opened

Adapted from PixArt, the following script allows running it on low-VRAM GPUs:

# pip install -U accelerate transformers bitsandbytes
# pip install -U git+https://github.com/huggingface/diffusers

from transformers import T5EncoderModel
from diffusers import DiffusionPipeline
import torch
import gc


def flush():
    """Free memory between pipeline stages.

    Runs Python's garbage collector first, then asks PyTorch to empty its
    CUDA cache. The order is kept as-is: collection happens before the cache
    is emptied.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

def bytes_to_giga_bytes(bytes):
    """Convert a byte count to gibibytes (divide by 1024 three times).

    Args:
        bytes: Size in bytes (int or float).

    Returns:
        The size as a float number of 1024**3-byte units.
    """
    kibibytes = bytes / 1024
    mebibytes = kibibytes / 1024
    gibibytes = mebibytes / 1024
    return gibibytes

# Staged inference script: encode the prompt first, free that memory, then
# load the pipeline again without a text encoder to denoise and decode.
# Statement order matters: each `del` + `flush()` runs before the next load.

# Model repository used for every `from_pretrained` call below.
model_name="ptx0/pixart-900m-1024-ft"
# Loading in 8 bits needs `bitsandbytes`.
text_encoder = T5EncoderModel.from_pretrained(
    model_name,
    subfolder="text_encoder",
    load_in_8bit=True,
    device_map="auto",

)

# Stage 1 pipeline: reuses the 8-bit text encoder loaded above and passes
# `transformer=None`, since only `encode_prompt` is called on this instance.
pipe = DiffusionPipeline.from_pretrained(
    model_name,
    text_encoder=text_encoder,
    transformer=None,
    device_map="balanced"
)

# Encode the prompt once, with gradients disabled; the four returned values
# are fed to the second pipeline in place of a raw prompt string.
with torch.no_grad():
    prompt = "A landscape photograph of a small cottage in the middle of a field of wild flowers with mountains off in the distance at sunset"
    prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)

# Drop the references to the text encoder and the stage 1 pipeline, then
# call flush() (gc.collect + torch.cuda.empty_cache) before loading more.
del text_encoder
del pipe
flush()

# Stage 2 pipeline: loaded in float16 with `text_encoder=None` and moved to
# the CUDA device.
pipe = DiffusionPipeline.from_pretrained(
    model_name,
    text_encoder=None,
    torch_dtype=torch.float16,
).to("cuda")

# Run the pipeline from the precomputed embeddings. `output_type="latent"`
# is requested here and the result is decoded by the VAE further down.
latents = pipe(
    negative_prompt=None, 
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_images_per_prompt=1,
    output_type="latent",
).images

# Remove the transformer from the pipeline and flush before the VAE decode.
del pipe.transformer
flush()

# Stage 3: decode the latents (divided by the VAE's configured scaling
# factor) and post-process the result with output_type="pil".
with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")
# NOTE(review): `image` is not saved or displayed anywhere in this snippet.
# Presumably `postprocess` returns a list of PIL images, so something like
# `image[0].save(...)` would be needed to keep the output -- confirm.

# Final cleanup, then print the value of torch.cuda.max_memory_allocated()
# converted with bytes_to_giga_bytes.
del pipe
flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)

Sign up or log in to comment