import os
import io
import base64
import contextlib

import torch
from PIL import Image
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline

# Local module providing fp12 replacements for torch.nn.Linear / torch.nn.Conv2d
from fp12 import Linear, Conv2d


pipe = None

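# Configuration:
#   PATH_TO_MODEL     - local safetensors checkpoint loaded via from_single_file()
#   USE_FP12          - when True, swap eligible UNet submodules for fp12 layers
#   FP12_ONLY_ATTN    - only touch submodules whose name contains 'attn'
#   FP12_APPLY_LINEAR - replace torch.nn.Linear with fp12.Linear
#   FP12_APPLY_CONV   - replace torch.nn.Conv2d with fp12.Conv2d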
PATH_TO_MODEL = "./animagineXLV3_v30.safetensors"
USE_FP12 = True
FP12_ONLY_ATTN = True
FP12_APPLY_LINEAR = False
FP12_APPLY_CONV = False


def free_memory():
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def to_fp12(module: torch.nn.Module):
    # Build the list of (original class, fp12 replacement) pairs to swap.
    target_modules = []

    if FP12_APPLY_LINEAR:
        target_modules.append((torch.nn.Linear, Linear))

    if FP12_APPLY_CONV:
        target_modules.append((torch.nn.Conv2d, Conv2d))

    # Replace matching direct children of this module in place.
    for name, mod in list(module.named_children()):
        for orig_class, fp12_class in target_modules:
            if isinstance(mod, orig_class):
                try:
                    new_mod = fp12_class(mod)
                except Exception as e:
                    print(f' -> failed: {name} {str(e)}')
                    continue

                delattr(module, name)
                del mod

                setattr(module, name, new_mod)
                break


def load_model_cpu(path: str):
    # Load the SDXL checkpoint on the CPU in fp16; it is moved to the GPU in load_model().
    pipe = StableDiffusionXLPipeline.from_single_file(
        path,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    return pipe


def replace_fp12(pipe: DiffusionPipeline):
    # Walk the UNet and swap eligible submodules for their fp12 counterparts.
    for name, mod in pipe.unet.named_modules():
        if FP12_ONLY_ATTN and 'attn' not in name:
            continue
        print('[fp12] REPLACE', name)
        to_fp12(mod)
    return pipe


@contextlib.contextmanager
def cuda_profiler(device: str):
    # Yields a dict that is filled with the elapsed GPU time (ms) and the peak
    # allocated memory (bytes) of the enclosed block once the block exits.
    cuda_start = torch.cuda.Event(enable_timing=True)
    cuda_end = torch.cuda.Event(enable_timing=True)

    obj = {}

    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats(device)
    cuda_start.record()

    try:
        yield obj
    finally:
        # Record in the finally block so the stats are captured even if the
        # profiled block raises.
        cuda_end.record()
        torch.cuda.synchronize()
        obj['time'] = cuda_start.elapsed_time(cuda_end)
        obj['memory'] = torch.cuda.max_memory_allocated(device)


def generate(pipe: DiffusionPipeline, prompt: str, negative_prompt: str, seed: int, device: str, use_amp: bool = False, guidance_scale = None, steps = None):
    import torch.amp

    # autocast expects a device *type* such as 'cuda', not a full device
    # string such as 'cuda:0'.
    device_type = torch.device(device).type
    context = (
        torch.amp.autocast_mode.autocast if use_amp
        else contextlib.nullcontext
    )

    with torch.no_grad(), context(device_type):
        rng = torch.Generator(device=device)
        if 0 <= seed:
            rng = rng.manual_seed(seed)

        # Return raw latents; VAE decoding happens later in save_image().
        latents, *_ = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=1024,
            height=1024,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            generator=rng,
            device=device,
            return_dict=False,
            output_type='latent',
        )

    return latents


def save_image(pipe, latents):
    with torch.no_grad():
        images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
        images = pipe.image_processor.postprocess(images, output_type='pil')

    # num_images_per_prompt is 1, so return the single decoded PIL image.
    for image in images:
        return image


def load_model(model = None, device = None):
    global pipe

    model = model or PATH_TO_MODEL
    device = device or 'cuda:0'

    pipe = load_model_cpu(model)

    if USE_FP12:
        pipe = replace_fp12(pipe)

    # Profile how much VRAM and time moving the UNet to the GPU takes.
    free_memory()
    with cuda_profiler(device) as prof:
        pipe.unet = pipe.unet.to(device)
    print('LOAD VRAM', prof['memory'])
    print('LOAD TIME', prof['time'])

    pipe.text_encoder = pipe.text_encoder.to(device)
    pipe.text_encoder_2 = pipe.text_encoder_2.to(device)

    if torch.cuda.is_available():
        torch.cuda.synchronize(device)


def run(prompt = None, negative_prompt = None, model = None, guidance_scale = None, steps = None, seed = None, device: str = None, use_amp: bool = False):
    global pipe

    if not pipe:
        load_model(model)

    _prompt = "masterpiece, best quality, 1girl, portrait"
    _negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name"

    prompt = prompt or _prompt
    negative_prompt = negative_prompt or _negative_prompt
    guidance_scale = float(guidance_scale) if guidance_scale else 5.0
    steps = int(steps) if steps else 20
    seed = int(seed) if seed else -1  # -1 leaves the generator unseeded
    device = device or 'cuda:0'

    # Profile the denoising (UNet) pass.
    free_memory()
    with cuda_profiler(device) as prof:
        latents = generate(pipe, prompt, negative_prompt, seed, device, use_amp, guidance_scale, steps)
    print('UNET VRAM', prof['memory'])
    print('UNET TIME', prof['time'])

    # Move the VAE to the GPU only for decoding; slicing lowers peak VRAM usage.
    free_memory()
    pipe.vae = pipe.vae.to(device)
    pipe.vae.enable_slicing()
    return save_image(pipe, latents)


def pil_to_webp(img):
    # Encode a PIL image as WebP and return the raw bytes.
    buffer = io.BytesIO()
    img.save(buffer, 'webp')
    return buffer.getvalue()


def bin_to_base64(bin):
    # Encode raw bytes as an ASCII base64 string.
    return base64.b64encode(bin).decode('ascii')
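

# Minimal usage sketch (assumes fp12.py and the checkpoint configured above are
# present and that a CUDA device is available; the output filename is illustrative).
if __name__ == '__main__':
    image = run(seed=42)       # returns a PIL image
    image.save('output.png')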