import os
import io
import base64
from PIL import Image
import contextlib
import torch
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from fp12 import Linear, Conv2d
pipe = None
PATH_TO_MODEL = "./animagineXLV3_v30.safetensors"
USE_FP12 = True
FP12_ONLY_ATTN = True
FP12_APPLY_LINEAR = False
FP12_APPLY_CONV = False
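# USE_FP12:          swap UNet layers for the fp12 wrappers at load time.
# FP12_ONLY_ATTN:    only touch submodules whose name contains 'attn'.
# FP12_APPLY_LINEAR / FP12_APPLY_CONV: which layer types get wrapped
#                    (nn.Linear -> fp12.Linear, nn.Conv2d -> fp12.Conv2d).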
# ==============================================================================
# Model loading
# ==============================================================================
def free_memory():
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def to_fp12(module: torch.nn.Module):
    target_modules = []
    if FP12_APPLY_LINEAR:
        target_modules.append((torch.nn.Linear, Linear))
    if FP12_APPLY_CONV:
        target_modules.append((torch.nn.Conv2d, Conv2d))
    for name, mod in list(module.named_children()):
        for orig_class, fp12_class in target_modules:
            if isinstance(mod, orig_class):
                try:
                    new_mod = fp12_class(mod)
                except Exception as e:
                    print(f' -> failed: {name} {str(e)}')
                    continue
                delattr(module, name)
                del mod
                setattr(module, name, new_mod)
                break

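# Note: to_fp12() only swaps a module's *direct* children; replace_fp12()
# below walks every UNet submodule, so nested Linear/Conv2d layers are still
# reached. A minimal sketch (the attention path is hypothetical, and it assumes
# the fp12 wrappers take the original module as their only constructor argument,
# as the calls above do):
#
#   FP12_APPLY_LINEAR = True
#   block = pipe.unet.down_blocks[1].attentions[0]
#   to_fp12(block)   # replaces the block's direct nn.Linear children in place
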
def load_model_cpu(path: str):
    pipe = StableDiffusionXLPipeline.from_single_file(
        path,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    return pipe

def replace_fp12(pipe: DiffusionPipeline):
    for name, mod in pipe.unet.named_modules():
        if FP12_ONLY_ATTN and 'attn' not in name:
            continue
        print('[fp12] REPLACE', name)
        to_fp12(mod)
    return pipe

@contextlib.contextmanager
def cuda_profiler(device: str):
    """Measure elapsed GPU time (ms) and peak allocated VRAM (bytes) of a block."""
    cuda_start = torch.cuda.Event(enable_timing=True)
    cuda_end = torch.cuda.Event(enable_timing=True)
    obj = {}
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats(device)
    cuda_start.record()
    try:
        yield obj
    finally:
        # Record even if the profiled block raises, so stats are still filled in.
        cuda_end.record()
        torch.cuda.synchronize()
        obj['time'] = cuda_start.elapsed_time(cuda_end)
        obj['memory'] = torch.cuda.max_memory_allocated(device)

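# Usage sketch: prof['time'] is milliseconds, prof['memory'] is peak bytes
# allocated on the device while the block ran, e.g.
#
#   with cuda_profiler('cuda:0') as prof:
#       pipe.unet.to('cuda:0')          # any CUDA work
#   print(prof['time'], prof['memory'])
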
# ==============================================================================
# Generation
# ==============================================================================
def generate(pipe: DiffusionPipeline, prompt: str, negative_prompt: str, seed: int, device: str, use_amp: bool = False, guidance_scale = None, steps = None):
    import torch.amp
    context = (
        torch.amp.autocast_mode.autocast if use_amp
        else contextlib.nullcontext
    )
    # autocast expects a device *type* ('cuda'/'cpu'), not 'cuda:0';
    # nullcontext simply ignores the argument.
    with torch.no_grad(), context(device.split(':')[0]):
        rng = torch.Generator(device=device)
        if 0 <= seed:
            rng = rng.manual_seed(seed)
        latents, *_ = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=1024,
            height=1024,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            generator=rng,
            device=device,
            return_dict=False,
            output_type='latent',
        )
    return latents

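# generate() returns raw latents (output_type='latent') rather than decoded
# images; run() moves the VAE to the GPU and enables slicing only afterwards,
# so the VAE decode does not add to peak VRAM during the UNet denoising loop.
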
def save_image(pipe, latents):
    with torch.no_grad():
        images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
        images = pipe.image_processor.postprocess(images, output_type='pil')
    for i, image in enumerate(images):
        #image.save(f'{i:02d}.png')
        return image

def load_model(model = None, device = None):
    global pipe
    model = model or PATH_TO_MODEL
    device = device or 'cuda:0'
    pipe = load_model_cpu(model)
    if USE_FP12:
        pipe = replace_fp12(pipe)
    free_memory()
    with cuda_profiler(device) as prof:
        pipe.unet = pipe.unet.to(device)
    print('LOAD VRAM', prof['memory'])
    print('LOAD TIME', prof['time'])
    pipe.text_encoder = pipe.text_encoder.to(device)
    pipe.text_encoder_2 = pipe.text_encoder_2.to(device)
    if torch.cuda.is_available():
        torch.cuda.synchronize(device)

def run(prompt = None, negative_prompt = None, model = None, guidance_scale = None, steps = None, seed = None, device: str = None, use_amp: bool = False):
    global pipe
    if not pipe:
        load_model(model)
    _prompt = "masterpiece, best quality, 1girl, portrait"
    _negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name"
    prompt = prompt or _prompt
    negative_prompt = negative_prompt or _negative_prompt
    guidance_scale = float(guidance_scale) if guidance_scale else 5.0
    steps = int(steps) if steps else 20
    seed = int(seed) if seed else -1
    device = device or 'cuda:0'
    free_memory()
    with cuda_profiler(device) as prof:
        latents = generate(pipe, prompt, negative_prompt, seed, device, use_amp, guidance_scale, steps)
    print('UNET VRAM', prof['memory'])
    print('UNET TIME', prof['time'])
    #pipe.unet = pipe.unet.to('cpu')
    #pipe.text_encoder = pipe.text_encoder.to('cpu')
    #pipe.text_encoder_2 = pipe.text_encoder_2.to('cpu')
    free_memory()
    pipe.vae = pipe.vae.to(device)
    pipe.vae.enable_slicing()
    return save_image(pipe, latents)

def pil_to_webp(img):
    buffer = io.BytesIO()
    img.save(buffer, 'webp')
    return buffer.getvalue()


def bin_to_base64(bin):
    return base64.b64encode(bin).decode('ascii')
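

# End-to-end sketch (hypothetical driver, not part of the module above):
# generate one image, encode it as WebP and then base64, e.g. for a JSON API.
#
#   if __name__ == '__main__':
#       image = run(prompt='1girl, solo, upper body', steps=20, seed=42)
#       print(bin_to_base64(pil_to_webp(image))[:64], '...')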