import os
import io
import base64
import contextlib

import torch
from PIL import Image
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline

from fp12 import Linear, Conv2d

pipe = None

PATH_TO_MODEL = "./animagineXLV3_v30.safetensors"
USE_FP12 = True
FP12_ONLY_ATTN = True
FP12_APPLY_LINEAR = False
FP12_APPLY_CONV = False


# ==============================================================================
# Model loading
# ==============================================================================

def free_memory():
    """Release unreferenced Python objects and the CUDA caching allocator."""
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def to_fp12(module: torch.nn.Module):
    """Replace the direct children of `module` with their fp12 equivalents."""
    target_modules = []
    if FP12_APPLY_LINEAR:
        target_modules.append((torch.nn.Linear, Linear))
    if FP12_APPLY_CONV:
        target_modules.append((torch.nn.Conv2d, Conv2d))

    for name, mod in list(module.named_children()):
        for orig_class, fp12_class in target_modules:
            if isinstance(mod, orig_class):
                try:
                    new_mod = fp12_class(mod)
                except Exception as e:
                    print(f' -> failed: {name} {str(e)}')
                    continue
                # Swap the original child module for the fp12 version.
                delattr(module, name)
                del mod
                setattr(module, name, new_mod)
                break


def load_model_cpu(path: str):
    """Load an SDXL checkpoint from a single .safetensors file onto the CPU."""
    pipe = StableDiffusionXLPipeline.from_single_file(
        path,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    return pipe


def replace_fp12(pipe: DiffusionPipeline):
    """Walk the UNet and convert matching submodules to fp12."""
    for name, mod in pipe.unet.named_modules():
        if FP12_ONLY_ATTN and 'attn' not in name:
            continue
        print('[fp12] REPLACE', name)
        to_fp12(mod)
    return pipe


@contextlib.contextmanager
def cuda_profiler(device: str):
    """Measure elapsed GPU time (ms) and peak allocated memory for a block."""
    cuda_start = torch.cuda.Event(enable_timing=True)
    cuda_end = torch.cuda.Event(enable_timing=True)
    obj = {}
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats(device)
    cuda_start.record()
    try:
        yield obj
    finally:
        cuda_end.record()
        torch.cuda.synchronize()
        obj['time'] = cuda_start.elapsed_time(cuda_end)
        obj['memory'] = torch.cuda.max_memory_allocated(device)


# ==============================================================================
# Generation
# ==============================================================================

def generate(pipe: DiffusionPipeline, prompt: str, negative_prompt: str, seed: int,
             device: str, use_amp: bool = False, guidance_scale=None, steps=None):
    """Run the diffusion loop and return latents; VAE decoding happens later."""
    import torch.amp

    context = (
        torch.amp.autocast_mode.autocast
        if use_amp
        else contextlib.nullcontext
    )
    # autocast expects a device *type* ('cuda'), not a full device string ('cuda:0').
    with torch.no_grad(), context(device.split(':')[0] if use_amp else device):
        rng = torch.Generator(device=device)
        if 0 <= seed:
            rng = rng.manual_seed(seed)
        latents, *_ = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=1024,
            height=1024,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            generator=rng,
            device=device,
            return_dict=False,
            output_type='latent',
        )
    return latents


def save_image(pipe, latents):
    """Decode latents with the VAE and return the first image as a PIL.Image."""
    with torch.no_grad():
        images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    images = pipe.image_processor.postprocess(images, output_type='pil')
    for i, image in enumerate(images):
        #image.save(f'{i:02d}.png')
        return image


def load_model(model=None, device=None):
    """Load the checkpoint, optionally apply fp12, and move it to the GPU."""
    global pipe
    model = model or PATH_TO_MODEL
    device = device or 'cuda:0'

    pipe = load_model_cpu(model)
    if USE_FP12:
        pipe = replace_fp12(pipe)

    free_memory()
    with cuda_profiler(device) as prof:
        pipe.unet = pipe.unet.to(device)
    print('LOAD VRAM', prof['memory'])
    print('LOAD TIME', prof['time'])

    pipe.text_encoder = pipe.text_encoder.to(device)
    pipe.text_encoder_2 = pipe.text_encoder_2.to(device)
    if torch.cuda.is_available():
        torch.cuda.synchronize(device)


def run(prompt=None, negative_prompt=None, model=None, guidance_scale=None,
        steps=None, seed=None, device: str = None, use_amp: bool = False):
    """Generate one 1024x1024 image and return it as a PIL.Image."""
    global pipe
    if not pipe:
        load_model(model)

    _prompt = "masterpiece, best quality, 1girl, portrait"
    _negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name"

    prompt = prompt or _prompt
    negative_prompt = negative_prompt or _negative_prompt
    guidance_scale = float(guidance_scale) if guidance_scale else 5.0
    steps = int(steps) if steps else 20
    seed = int(seed) if seed else -1
    device = device or 'cuda:0'

    free_memory()
    with cuda_profiler(device) as prof:
        latents = generate(pipe, prompt, negative_prompt, seed, device, use_amp, guidance_scale, steps)
    print('UNET VRAM', prof['memory'])
    print('UNET TIME', prof['time'])

    #pipe.unet = pipe.unet.to('cpu')
    #pipe.text_encoder = pipe.text_encoder.to('cpu')
    #pipe.text_encoder_2 = pipe.text_encoder_2.to('cpu')
    free_memory()

    pipe.vae = pipe.vae.to(device)
    pipe.vae.enable_slicing()
    return save_image(pipe, latents)


def pil_to_webp(img):
    """Encode a PIL image as WebP and return the raw bytes."""
    buffer = io.BytesIO()
    img.save(buffer, 'webp')
    return buffer.getvalue()


def bin_to_base64(bin):
    """Encode raw bytes as an ASCII base64 string."""
    return base64.b64encode(bin).decode('ascii')
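
# Minimal usage sketch (an assumption, not part of the original script): load the
# checkpoint configured via PATH_TO_MODEL, render one image with the default prompt,
# and write it to disk. The seed/step values and the 'out.webp' path are hypothetical.
if __name__ == '__main__':
    image = run(seed=42, steps=20)      # returns a PIL.Image
    webp = pil_to_webp(image)           # encode to WebP bytes
    with open('out.webp', 'wb') as f:   # hypothetical output path
        f.write(webp)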