edgeupdate1 / src /pipeline.py
manbeast3b's picture
Update src/pipeline.py
8152a67 verified
import os
import gc
import time
import torch
import torchvision
import torch.nn as nn
from torch import Generator
from diffusers import (
FluxPipeline,
AutoencoderKL,
AutoencoderTiny,
DiffusionPipeline,
FluxTransformer2DModel
)
from diffusers.image_processor import VaeImageProcessor
from transformers import (
T5EncoderModel,
CLIPTextModel
)
from PIL import Image as img
from pipelines.models import TextToImageRequest
from model import E, D
# Environment configuration
# Ask PyTorch's CUDA caching allocator to use expandable segments.
# NOTE(review): this is assigned after `import torch`; presumably the allocator
# only reads the variable on first CUDA use, so it still takes effect - confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
# Let cuDNN auto-tune and pick convolution algorithms per input configuration.
torch.backends.cudnn.benchmark = True
# Permit TF32 math for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
# Cap this process at 95% of the device's memory.
# NOTE(review): this call runs at import time, so importing the module
# presumably requires a usable CUDA device - confirm.
torch.cuda.set_per_process_memory_fraction(0.95)
# Constants
# Hugging Face repo id used below for the CLIP text encoder and the base pipeline.
CKPT_ID = "black-forest-labs/FLUX.1-schnell"
# Utility functions
def clear():
    """Free unreferenced Python objects and reset CUDA memory bookkeeping.

    Runs the garbage collector, then (only when CUDA is available) releases
    cached allocator blocks and resets the peak-memory statistics.

    The deprecated ``torch.cuda.reset_max_memory_allocated()`` call was
    dropped: it merely forwards to ``reset_peak_memory_stats()`` and emits a
    warning, so calling both was redundant.
    """
    gc.collect()
    # On CPU-only hosts there is nothing to release, and resetting CUDA
    # statistics would fail, so stop here.
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
# Quantization classes
class BasicQuantization:
    """Simulated uniform affine quantization of a tensor.

    Values are mapped onto the signed integer grid ``[qmin, qmax]`` implied by
    ``bits`` and immediately mapped back, so the result keeps the input's
    dtype but only takes values representable on that grid.
    """

    def __init__(self, bits=1):
        """Store the bit width and derive the signed integer range.

        Args:
            bits: Number of bits of the simulated integer type.
        """
        self.bits = bits
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1

    def quantize_tensor(self, tensor):
        """Quantize and de-quantize ``tensor`` with a per-tensor scale.

        Args:
            tensor: Tensor to quantize; its min and max define the range.

        Returns:
            A tuple ``(dequantized, scale, zero_point)``. For a tensor whose
            values are all equal (zero range) the tensor is returned
            unchanged together with ``scale == 1`` and ``zero_point == 0``.
        """
        lowest = tensor.min()
        scale = (tensor.max() - lowest) / (self.qmax - self.qmin)
        if scale == 0:
            # A constant tensor has no range to quantize over; dividing by a
            # zero scale would fill the result with NaN/inf, so pass it through.
            return tensor.clone(), torch.ones_like(scale), torch.zeros_like(scale)
        zero_point = self.qmin - torch.round(lowest / scale)
        qtensor = torch.round(tensor / scale + zero_point)
        qtensor = torch.clamp(qtensor, self.qmin, self.qmax)
        return (qtensor - zero_point) * scale, scale, zero_point
class ModelQuantization:
    """Applies simulated quantization to every ``nn.Linear`` in a model."""

    def __init__(self, model, bits=7):
        """Remember the model and build the tensor quantizer.

        Args:
            model: Module whose linear layers will be quantized in place.
            bits: Bit width handed to ``BasicQuantization``.
        """
        self.model = model
        self.quant = BasicQuantization(bits)

    def quantize_model(self):
        """Replace weights and biases of all linear layers by quantized copies.

        Each replacement parameter keeps the ``requires_grad`` flag of the
        parameter it replaces. ``nn.Parameter`` defaults to
        ``requires_grad=True``, which would otherwise silently un-freeze
        layers that were frozen before quantization.
        """
        # Quantization is a one-off weight transform; no autograd graph needed.
        with torch.no_grad():
            for module in self.model.modules():
                if not isinstance(module, nn.Linear):
                    continue
                weight = getattr(module, 'weight', None)
                if weight is not None:
                    quantized_weight, _, _ = self.quant.quantize_tensor(weight)
                    module.weight = nn.Parameter(
                        quantized_weight, requires_grad=weight.requires_grad
                    )
                bias = getattr(module, 'bias', None)
                if bias is not None:
                    quantized_bias, _, _ = self.quant.quantize_tensor(bias)
                    module.bias = nn.Parameter(
                        quantized_bias, requires_grad=bias.requires_grad
                    )
# Pipeline loading
def load_pipeline():
    """Build, warm up and return the text-to-image diffusion pipeline.

    Steps:
      1. Load a tiny VAE, swap in the custom ``E``/``D`` encoder and decoder,
         load their weights from ``ko.pth`` / ``ok.pth``, freeze them and run
         the linear-layer quantizer over the VAE.
      2. Load the CLIP and T5 text encoders and the transformer.
      3. Assemble the pipeline on the CUDA device, switch the components to
         the channels-last memory format and run one warm-up generation.

    Returns:
        The assembled pipeline, moved to ``"cuda"``.
    """
    clear()
    dtype, device = torch.bfloat16, "cuda"

    # Load VAE with custom encoder/decoder
    vae = AutoencoderTiny.from_pretrained("manbeast3b/flux.1-schnell-vae-quant1", torch_dtype=dtype)
    vae.encoder = E(16)
    vae.decoder = D(16)

    def lsd(p, mod, pfx):
        """Load the entries of checkpoint ``p`` that fit ``mod``.

        ``pfx`` is removed from the start of each checkpoint key; entries
        whose resulting name is unknown to ``mod`` or whose shape differs are
        skipped.
        """
        sd = torch.load(p, map_location="cpu", weights_only=True)
        target = mod.state_dict()  # built once instead of once per key
        f_sd = {}
        for k, v in sd.items():
            # Remove the prefix itself. ``str.strip(pfx)`` would instead strip
            # any of the prefix's *characters* from both ends of the key and
            # mangle names such as "encoder.conv_in.weight".
            name = k[len(pfx):] if k.startswith(pfx) else k
            if name in target and v.size() == target[name].size():
                f_sd[name] = v
        mod.load_state_dict(f_sd, strict=False)
        mod.to(dtype=torch.bfloat16)

    lsd("ko.pth", vae.encoder, "encoder.")
    lsd("ok.pth", vae.decoder, "decoder.")
    vae.encoder.requires_grad_(False)
    vae.decoder.requires_grad_(False)

    # Quantize model
    quantizer = ModelQuantization(vae)
    quantizer.quantize_model()

    text_encoder = CLIPTextModel.from_pretrained(CKPT_ID, subfolder="text_encoder", torch_dtype=dtype)
    text_encoder_2 = T5EncoderModel.from_pretrained("city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=dtype)

    # Transformer model
    transformer_model = FluxTransformer2DModel.from_pretrained(
        "/home/sandbox/.cache/huggingface/hub/models--manbeast3b--flux-schnell-transformer2d-int8-mod/snapshots/c911be0ba0d99bb717c242346c21740e7fe20ddf/",
        torch_dtype=dtype,
        use_safetensors=False
    )

    # pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        CKPT_ID,
        transformer=transformer_model,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=dtype
    ).to(device)

    # Optimize memory format
    for component in [pipeline.vae, pipeline.text_encoder, pipeline.text_encoder_2, pipeline.transformer]:
        component.to(memory_format=torch.channels_last)

    # Warm-up inference: one throwaway generation with the same settings that
    # infer() uses, so first-call setup cost is paid here.
    pipeline(
        prompt="modificator, drupaceous, jobbernowl, hereness",
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256
    )
    clear()
    return pipeline
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline):
    """Render one image for ``request`` and return it as a PIL image.

    The generator is seeded from ``request.seed`` on the pipeline's device,
    and the pipeline is run for four steps without guidance at the requested
    height and width. The first returned image tensor is cast to float32,
    rescaled as ``2 * x - 1`` and converted with torchvision.
    """
    seeded_rng = Generator(pipeline.device).manual_seed(request.seed)
    call_options = dict(
        generator=seeded_rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pt",
    )
    result = pipeline(request.prompt, **call_options)
    frame = result.images[0].to(torch.float32)
    # NOTE(review): the value range of the "pt" output is not visible from
    # here - confirm that 2*x - 1 is the intended mapping before PIL conversion.
    frame.mul_(2).sub_(1)
    return torchvision.transforms.functional.to_pil_image(frame)