# sdfp12 / fn.py
import os
import io
import base64
from PIL import Image
import contextlib
import torch
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from fp12 import Linear, Conv2d
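# Global pipeline handle (populated lazily by load_model) and fp12 settings:
# USE_FP12 toggles replacement of UNet layers with the fp12 wrappers from
# fp12.py, FP12_ONLY_ATTN restricts the replacement to attention blocks, and
# FP12_APPLY_LINEAR / FP12_APPLY_CONV select which layer types are wrapped.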
pipe = None
PATH_TO_MODEL = "./animagineXLV3_v30.safetensors"
USE_FP12 = True
FP12_ONLY_ATTN = True
FP12_APPLY_LINEAR = False
FP12_APPLY_CONV = False
# ==============================================================================
# Model loading
# ==============================================================================
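# Run the garbage collector and release cached CUDA memory.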
def free_memory():
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
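# Replace the direct children of `module` that are nn.Linear / nn.Conv2d
# (as selected by FP12_APPLY_LINEAR / FP12_APPLY_CONV) with the fp12 wrapper
# classes imported from fp12.py; failures are logged and skipped.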
def to_fp12(module: torch.nn.Module):
target_modules = []
if FP12_APPLY_LINEAR:
target_modules.append((torch.nn.Linear, Linear))
if FP12_APPLY_CONV:
target_modules.append((torch.nn.Conv2d, Conv2d))
for name, mod in list(module.named_children()):
for orig_class, fp12_class in target_modules:
if isinstance(mod, orig_class):
try:
new_mod = fp12_class(mod)
except Exception as e:
print(f' -> failed: {name} {str(e)}')
continue
delattr(module, name)
del mod
setattr(module, name, new_mod)
break
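# Load an SDXL checkpoint from a single .safetensors file in float16,
# keeping all components on the CPU.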
def load_model_cpu(path: str):
pipe = StableDiffusionXLPipeline.from_single_file(
path,
torch_dtype=torch.float16,
safety_checker=None,
)
return pipe
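# Walk every submodule of the UNet and apply to_fp12 to it; with
# FP12_ONLY_ATTN set, only modules whose qualified name contains 'attn'
# are considered.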
def replace_fp12(pipe: DiffusionPipeline):
for name, mod in pipe.unet.named_modules():
if FP12_ONLY_ATTN and 'attn' not in name:
continue
print('[fp12] REPLACE', name)
to_fp12(mod)
return pipe
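# Context manager that measures elapsed GPU time (via CUDA events) and peak
# allocated VRAM for the wrapped block; results are exposed through the
# yielded dict under 'time' (milliseconds) and 'memory' (bytes).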
@contextlib.contextmanager
def cuda_profiler(device: str):
    cuda_start = torch.cuda.Event(enable_timing=True)
    cuda_end = torch.cuda.Event(enable_timing=True)
    obj = {}
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats(device)
    cuda_start.record()
    try:
        yield obj
    finally:
        # Record the end event in the finally block so the timing and peak
        # VRAM figures are captured even if the wrapped block raises.
        cuda_end.record()
        torch.cuda.synchronize()
        obj['time'] = cuda_start.elapsed_time(cuda_end)
        obj['memory'] = torch.cuda.max_memory_allocated(device)
# ==============================================================================
# Generation
# ==============================================================================
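# Run the diffusion pipeline up to the latent stage (output_type='latent') at
# a fixed 1024x1024 resolution, optionally under autocast; a non-negative seed
# makes the run reproducible.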
def generate(pipe: DiffusionPipeline, prompt: str, negative_prompt: str, seed: int, device: str, use_amp: bool = False, guidance_scale = None, steps = None):
import contextlib
import torch.amp
context = (
torch.amp.autocast_mode.autocast if use_amp
else contextlib.nullcontext
)
with torch.no_grad(), context(device):
rng = torch.Generator(device=device)
if 0 <= seed:
rng = rng.manual_seed(seed)
latents, *_ = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=1024,
height=1024,
num_inference_steps=steps,
guidance_scale=guidance_scale,
num_images_per_prompt=1,
generator=rng,
device=device,
return_dict=False,
output_type='latent',
)
return latents
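# Decode the latents with the VAE, post-process to PIL, and return the first
# (and, with num_images_per_prompt=1, only) image.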
def save_image(pipe, latents):
with torch.no_grad():
images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
images = pipe.image_processor.postprocess(images, output_type='pil')
for i, image in enumerate(images):
#image.save(f'{i:02d}.png')
return image
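# Build the pipeline, optionally swap in fp12 layers, and move the UNet and
# both text encoders to the target device, printing the time and VRAM used
# by the UNet transfer.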
def load_model(model = None, device = None):
global pipe
model = model or PATH_TO_MODEL
device = device or 'cuda:0'
pipe = load_model_cpu(model)
if USE_FP12:
pipe = replace_fp12(pipe)
free_memory()
with cuda_profiler(device) as prof:
pipe.unet = pipe.unet.to(device)
print('LOAD VRAM', prof['memory'])
print('LOAD TIME', prof['time'])
pipe.text_encoder = pipe.text_encoder.to(device)
pipe.text_encoder_2 = pipe.text_encoder_2.to(device)
if torch.cuda.is_available():
torch.cuda.synchronize(device)
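# Entry point: lazily load the model, fill in default prompt/settings, run the
# UNet denoising loop under the profiler, then move the VAE to the device and
# decode the latents into a PIL image.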
def run(prompt = None, negative_prompt = None, model = None, guidance_scale = None, steps = None, seed = None, device: str = None, use_amp: bool = False):
global pipe
if not pipe:
load_model(model)
_prompt = "masterpiece, best quality, 1girl, portrait"
_negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name"
prompt = prompt or _prompt
negative_prompt = negative_prompt or _negative_prompt
    guidance_scale = float(guidance_scale) if guidance_scale not in (None, '') else 5.0
    steps = int(steps) if steps not in (None, '') else 20
    seed = int(seed) if seed not in (None, '') else -1  # 0 is a valid fixed seed; negative leaves generation unseeded
device = device or 'cuda:0'
free_memory()
with cuda_profiler(device) as prof:
latents = generate(pipe, prompt, negative_prompt, seed, device, use_amp, guidance_scale, steps)
print('UNET VRAM', prof['memory'])
print('UNET TIME', prof['time'])
#pipe.unet = pipe.unet.to('cpu')
#pipe.text_encoder = pipe.text_encoder.to('cpu')
#pipe.text_encoder_2 = pipe.text_encoder_2.to('cpu')
free_memory()
pipe.vae = pipe.vae.to(device)
pipe.vae.enable_slicing()
return save_image(pipe, latents)
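# Encode a PIL image as WebP and return the raw bytes.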
def pil_to_webp(img):
buffer = io.BytesIO()
img.save(buffer, 'webp')
return buffer.getvalue()
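# Base64-encode a bytes object into an ASCII string.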
def bin_to_base64(bin):
return base64.b64encode(bin).decode('ascii')
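# Example invocation (illustrative; assumes PATH_TO_MODEL exists and a CUDA
# device is available; the prompt and output filename are placeholders):
if __name__ == '__main__':
    image = run(prompt="1girl, solo, upper body, smile", seed=42)
    image.save('sample.png')
    # For a base64 WebP payload instead: bin_to_base64(pil_to_webp(image))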