import os
import gc
import torch
import torchvision
import torch.nn as nn
from torch import Generator
from diffusers import (
    AutoencoderTiny,
    DiffusionPipeline,
    FluxTransformer2DModel,
)
from transformers import (
    T5EncoderModel,
    CLIPTextModel,
)
from pipelines.models import TextToImageRequest
from model import E, D

# Environment configuration
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.cuda.set_per_process_memory_fraction(0.95)

# Constants
CKPT_ID = "black-forest-labs/FLUX.1-schnell"


# Utility functions
def clear():
    """Free Python and CUDA caches and reset peak-memory statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is a deprecated alias of this call.
    torch.cuda.reset_peak_memory_stats()


# Quantization classes
class BasicQuantization:
    """Affine (asymmetric) quantize-dequantize of a tensor to `bits` bits."""

    def __init__(self, bits=1):
        self.bits = bits
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1

    def quantize_tensor(self, tensor):
        # Map [tensor.min(), tensor.max()] onto the integer range [qmin, qmax].
        scale = (tensor.max() - tensor.min()) / (self.qmax - self.qmin)
        scale = torch.clamp(scale, min=1e-8)  # guard against constant tensors
        zero_point = self.qmin - torch.round(tensor.min() / scale)
        qtensor = torch.round(tensor / scale + zero_point)
        qtensor = torch.clamp(qtensor, self.qmin, self.qmax)
        # Return the dequantized tensor so it can replace the original weights.
        return (qtensor - zero_point) * scale, scale, zero_point


class ModelQuantization:
    """Applies quantize-dequantize in place to every nn.Linear in a model."""

    def __init__(self, model, bits=7):
        self.model = model
        self.quant = BasicQuantization(bits)

    @torch.no_grad()
    def quantize_model(self):
        for module in self.model.modules():
            if isinstance(module, nn.Linear):
                if hasattr(module, "weight"):
                    quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
                    module.weight = nn.Parameter(quantized_weight)
                if hasattr(module, "bias") and module.bias is not None:
                    quantized_bias, _, _ = self.quant.quantize_tensor(module.bias)
                    module.bias = nn.Parameter(quantized_bias)


# Pipeline loading
def load_pipeline():
    """Loads and prepares the diffusion pipeline."""
    clear()
    dtype, device = torch.bfloat16, "cuda"

    # Load the VAE and swap in the custom encoder/decoder from model.py
    vae = AutoencoderTiny.from_pretrained(
        "manbeast3b/flux.1-schnell-vae-quant1", torch_dtype=dtype
    )
    vae.encoder = E(16)
    vae.decoder = D(16)

    def lsd(path, module, prefix):
        """Load a checkpoint into `module`, keeping only keys that match its
        state dict (by name and shape) after stripping `prefix`."""
        sd = torch.load(path, map_location="cpu", weights_only=True)
        target = module.state_dict()
        # str.removeprefix, not str.strip: strip() removes *characters* from
        # both ends of the string and would mangle the keys.
        filtered = {
            k.removeprefix(prefix): v
            for k, v in sd.items()
            if k.removeprefix(prefix) in target
            and v.size() == target[k.removeprefix(prefix)].size()
        }
        module.load_state_dict(filtered, strict=False)
        module.to(dtype=dtype)

    lsd("ko.pth", vae.encoder, "encoder.")
    lsd("ok.pth", vae.decoder, "decoder.")
    vae.encoder.requires_grad_(False)
    vae.decoder.requires_grad_(False)

    # Quantize the VAE's linear layers
    quantizer = ModelQuantization(vae)
    quantizer.quantize_model()

    # Text encoders
    text_encoder = CLIPTextModel.from_pretrained(
        CKPT_ID, subfolder="text_encoder", torch_dtype=dtype
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=dtype
    )

    # Transformer model (pre-quantized int8 snapshot from the local HF cache)
    transformer_model = FluxTransformer2DModel.from_pretrained(
        "/home/sandbox/.cache/huggingface/hub/models--manbeast3b--flux-schnell-transformer2d-int8-mod/snapshots/c911be0ba0d99bb717c242346c21740e7fe20ddf/",
        torch_dtype=dtype,
        use_safetensors=False,
    )

    # Assemble the pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        CKPT_ID,
        transformer=transformer_model,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=dtype,
    ).to(device)

    # Optimize memory format
    for component in [
        pipeline.vae,
        pipeline.text_encoder,
        pipeline.text_encoder_2,
        pipeline.transformer,
    ]:
        component.to(memory_format=torch.channels_last)
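    # Warm-up inference: one throwaway generation so cuDNN autotuning
    # (cudnn.benchmark) and CUDA allocator growth happen here rather than
    # inside a timed request.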
    pipeline(
        prompt="modificator, drupaceous, jobbernowl, hereness",
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    clear()
    return pipeline


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline):
    """Generates an image based on the given request."""
    generator = Generator(pipeline.device).manual_seed(request.seed)
    image = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pt",
    ).images[0]
    return torchvision.transforms.functional.to_pil_image(
        image.to(torch.float32).mul_(2).sub_(1)
    )
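
# Minimal local usage sketch. Assumption: TextToImageRequest (from
# pipelines.models) accepts these fields as constructor keywords; adjust to
# the real request model if it differs.
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(
        prompt="a lighthouse at dusk, oil painting",
        seed=0,
        height=1024,
        width=1024,
    )
    infer(request, pipe).save("sample.png")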