import base64
import json
import sys
from collections import defaultdict
from io import BytesIO
from pprint import pprint
from typing import Any, Dict, List
import os

import torch
from diffusers import (
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
)
from safetensors.torch import load_file
from torch import autocast

# https://huggingface.co/philschmid/stable-diffusion-v1-4-endpoints
# https://huggingface.co/docs/inference-endpoints/guides/custom_handler

# if local avoid repo url
LOCAL = False

PREFIX_URL = ""
if not LOCAL:
    # NOTE(review): these prefixed paths are handed to safetensors `load_file`
    # and to `load_textual_inversion(weight_name=...)`. `load_file` reads a
    # local file, so an https URL (and a `tree/main` web-page URL in
    # particular) is unlikely to load. Confirm how paths resolve on the
    # endpoint, or deploy with LOCAL = True.
    PREFIX_URL = "https://huggingface.co/isatis/kw/tree/main/"

# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type != "cuda":
    raise ValueError("need to run on GPU")


class EndpointHandler:
    """Custom Inference Endpoints handler for Stable Diffusion text-to-image.

    On construction the pipeline is loaded with:
    - the ``lpw_stable_diffusion`` custom pipeline (long prompts),
    - a DPM++ 2M SDE Karras scheduler,
    - negative textual-inversion embeddings,
    - a default set of LoRAs merged permanently into the weights.

    Calling the handler generates one image per request. Any LoRAs requested
    in the payload are merged for that request only and removed afterwards.
    """

    print(os.getcwd())

    # LoRA name -> safetensors file, selectable via the "loras_model" field.
    LORA_PATHS = {
        "hairdetailer": PREFIX_URL + "lora/hairdetailer.safetensors",
        "lora_leica": PREFIX_URL + "lora/lora_leica.safetensors",
        "epiNoiseoffset_v2": PREFIX_URL + "lora/epiNoiseoffset_v2.safetensors",
        "MBHU-TT2FRS": PREFIX_URL + "lora/MBHU-TT2FRS.safetensors",
        "ShinyOiledSkin_v20": PREFIX_URL + "lora/ShinyOiledSkin_v20-LoRA.safetensors",
        "polyhedron_new_skin_v1.1": PREFIX_URL
        + "lora/polyhedron_new_skin_v1.1.safetensors",
        "detailed_eye-10": PREFIX_URL + "lora/detailed_eye-10.safetensors",
        "add_detail": PREFIX_URL + "lora/add_detail.safetensors",
        "MuscleGirl_v1": PREFIX_URL + "lora/MuscleGirl_v1.safetensors",
    }

    # Negative embeddings, each registered under the given prompt token.
    TEXTUAL_INVERSION = [
        {
            "weight_name": PREFIX_URL + "embeddings/EasyNegative.safetensors",
            "token": "easynegative",
        },
        {
            "weight_name": PREFIX_URL + "embeddings/EasyNegative.safetensors",
            "token": "EasyNegative",
        },
        {"weight_name": PREFIX_URL + "embeddings/badhandv4.pt", "token": "badhandv4"},
        {
            "weight_name": PREFIX_URL + "embeddings/bad-artist-anime.pt",
            "token": "bad-artist-anime",
        },
        {"weight_name": PREFIX_URL + "embeddings/NegfeetV2.pt", "token": "NegfeetV2"},
        {
            "weight_name": PREFIX_URL + "embeddings/ng_deepnegative_v1_75t.pt",
            "token": "ng_deepnegative_v1_75t",
        },
        {
            "weight_name": PREFIX_URL + "embeddings/ng_deepnegative_v1_75t.pt",
            "token": "NG_DeepNegative_V1_75T",
        },
        {
            "weight_name": PREFIX_URL + "embeddings/bad-hands-5.pt",
            "token": "bad-hands-5",
        },
    ]

    # Embedding tokens (registered above) appended to every negative prompt.
    FORCED_NEGATIVE = (
        "easynegative, badhandv4, bad-artist-anime, NegfeetV2, "
        "ng_deepnegative_v1_75t, bad-hands-5"
    )

    def __init__(self, path="."):
        """Load the pipeline from `path` and apply the default customisations.

        Args:
            path: local directory or model id passed to
                ``DiffusionPipeline.from_pretrained``.
        """
        # load the optimized model
        self.pipe = DiffusionPipeline.from_pretrained(
            path,
            custom_pipeline="lpw_stable_diffusion",  # avoid 77 token limit
            torch_dtype=torch.float16,  # accelerate render
        )
        self.pipe = self.pipe.to(device)

        # DPM++ 2M SDE Karras
        # increase step to avoid high contrast num_inference_steps=30
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config,
            use_karras_sigmas=True,
            algorithm_type="sde-dpmsolver++",
        )

        # Disable the safety checker.
        self.pipe.safety_checker = None

        # Load negative embeddings to avoid bad hands, etc
        self.load_embeddings()

        # Load default Lora models (merged permanently into the weights).
        self.pipe = self.load_selected_loras(
            [
                ("polyhedron_new_skin_v1.1", 0.35),  # nice Skin
                ("detailed_eye-10", 0.3),  # nice eyes
                ("add_detail", 0.4),  # detailed pictures
                ("MuscleGirl_v1", 0.3),  # shape persons
            ],
        )

        # boosts performance by another 20%
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe.enable_attention_slicing()

    def load_lora(self, pipeline, lora_path, lora_weight=0.5):
        """Merge a LoRA's low-rank deltas directly into `pipeline`'s weights.

        For every lora_up / lora_down pair in the safetensors file at
        `lora_path`, adds ``lora_weight * (up @ down)`` in place to the weight
        of the matching text-encoder (keys containing "text") or UNet layer.
        Calling again with the negated weight subtracts the same delta.

        Args:
            pipeline: diffusers pipeline exposing ``text_encoder`` and ``unet``.
            lora_path: path to the LoRA ``.safetensors`` file.
            lora_weight: multiplier applied to each delta.

        Returns:
            The same pipeline object, modified in place.

        Raises:
            ValueError: if a LoRA key cannot be mapped onto a model layer.
        """
        state_dict = load_file(lora_path)
        LORA_PREFIX_UNET = "lora_unet"
        LORA_PREFIX_TEXT_ENCODER = "lora_te"

        alpha = lora_weight
        visited = set()

        for key in state_dict:
            state_dict[key] = state_dict[key].to(device)

        # directly update weight in diffusers model
        for key in state_dict:
            # NOTE(review): per-module ".alpha" entries are skipped, so the
            # delta is scaled only by `lora_weight`, not by alpha / rank.
            if ".alpha" in key or key in visited:
                continue

            if "text" in key:
                layer_infos = (
                    key.split(".")[0]
                    .split(LORA_PREFIX_TEXT_ENCODER + "_")[-1]
                    .split("_")
                )
                curr_layer = pipeline.text_encoder
            else:
                layer_infos = (
                    key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
                )
                curr_layer = pipeline.unet

            # Find the target layer. Module names may themselves contain "_",
            # so when a lookup fails, glue on the next fragment and retry.
            temp_name = layer_infos.pop(0)
            while True:
                try:
                    curr_layer = curr_layer.__getattr__(temp_name)
                    if layer_infos:
                        temp_name = layer_infos.pop(0)
                    else:
                        break
                except AttributeError:
                    if not layer_infos:
                        raise ValueError(
                            f"Cannot locate model layer for LoRA key {key!r}"
                        ) from None
                    if len(temp_name) > 0:
                        temp_name += "_" + layer_infos.pop(0)
                    else:
                        temp_name = layer_infos.pop(0)

            # org_forward(x) + lora_up(lora_down(x)) * multiplier
            if "lora_down" in key:
                pair_keys = [key.replace("lora_down", "lora_up"), key]
            else:
                pair_keys = [key, key.replace("lora_up", "lora_down")]

            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            if weight_up.dim() == 4:
                # Conv LoRA: drop the trailing spatial dims, multiply, restore.
                # NOTE(review): squeeze only removes size-1 dims, so this
                # presumably assumes 1x1 kernels - confirm for the LoRAs used.
                delta = torch.mm(
                    weight_up.squeeze(3).squeeze(2),
                    weight_down.squeeze(3).squeeze(2),
                ).unsqueeze(2).unsqueeze(3)
            else:
                delta = torch.mm(weight_up, weight_down)
            curr_layer.weight.data += alpha * delta

            # update visited set
            visited.update(pair_keys)

        return pipeline

    def load_embeddings(self):
        """Load textual inversions, avoid bad prompts"""
        for model in EndpointHandler.TEXTUAL_INVERSION:
            self.pipe.load_textual_inversion(
                ".", weight_name=model["weight_name"], token=model["token"]
            )

    def load_selected_loras(self, selections):
        """Merge each ``(lora_name, weight)`` in `selections` into the pipeline.

        Names are keys of ``LORA_PATHS``; an unknown name raises ``KeyError``.
        Returns the (in-place modified) pipeline.
        """
        for model_name, weight in selections:
            lora_path = EndpointHandler.LORA_PATHS[model_name]
            self.pipe = self.load_lora(
                pipeline=self.pipe, lora_path=lora_path, lora_weight=weight
            )
        return self.pipe

    @classmethod
    def _build_negative_prompt(cls, negative_prompt):
        """Return the user negative prompt followed by FORCED_NEGATIVE, comma-separated."""
        user_part = (negative_prompt or "").strip().rstrip(",").strip()
        if not user_part:
            return cls.FORCED_NEGATIVE
        return f"{user_part}, {cls.FORCED_NEGATIVE}"

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate one image from a request payload.

        Args:
            data: request payload. Required keys are "prompt",
                "negative_prompt", "width", "height", "num_inference_steps",
                "seed" and "guidance_scale". The optional "loras_model" is a
                sequence of ``(lora_name, weight)`` pairs (names from
                ``LORA_PATHS``). Those LoRAs are merged for this request only
                and removed again afterwards. A "seed" of None means
                unseeded generation.

        Returns:
            ``{"flag": "success", "image": <base64-encoded JPEG>}`` on success,
            otherwise ``{"flag": "error", "message": <reason>}``.
        """
        # 1. Verify input arguments
        required_fields = [
            "prompt",
            "negative_prompt",
            "width",
            "num_inference_steps",
            "height",
            "seed",
            "guidance_scale",
        ]
        missing_fields = [field for field in required_fields if field not in data]
        if missing_fields:
            return {
                "flag": "error",
                "message": f"Missing fields: {', '.join(missing_fields)}",
            }

        # Now extract the fields
        prompt = data["prompt"]
        negative_prompt = data["negative_prompt"]
        loras_model = data.get("loras_model")
        seed = data["seed"]
        width = data["width"]
        num_inference_steps = data["num_inference_steps"]
        height = data["height"]
        guidance_scale = data["guidance_scale"]

        # Automatically append the negative embedding tokens.
        forced_negative = self._build_negative_prompt(negative_prompt)

        # Seed the generator whenever a seed is given (0 is a valid seed).
        generator = (
            torch.Generator(device=device).manual_seed(seed)
            if seed is not None
            else None
        )

        # Per-request LoRAs actually merged, so they can be removed again.
        applied = []
        try:
            for model_name, weight in loras_model or []:
                if model_name not in self.LORA_PATHS:
                    raise ValueError(f"Unknown LoRA model: {model_name}")
                self.pipe = self.load_lora(
                    pipeline=self.pipe,
                    lora_path=self.LORA_PATHS[model_name],
                    lora_weight=weight,
                )
                applied.append((model_name, weight))

            # 2. Process
            with autocast(device.type):
                image = self.pipe.text2img(
                    prompt=prompt,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    height=height,
                    width=width,
                    negative_prompt=forced_negative,
                    generator=generator,
                    max_embeddings_multiples=5,
                ).images[0]

            # encode image as base 64
            buffered = BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue())

            # Return the success response
            return {"flag": "success", "image": img_str.decode()}
        except Exception as e:
            # Handle any other exceptions and return an error response
            return {"flag": "error", "message": str(e)}
        finally:
            # Subtract the per-request deltas (in reverse order) so LoRAs do
            # not accumulate in the shared weights across requests.
            for model_name, weight in reversed(applied):
                self.pipe = self.load_lora(
                    pipeline=self.pipe,
                    lora_path=self.LORA_PATHS[model_name],
                    lora_weight=-weight,
                )