"""Model/LoRA management and JWT auth helpers for diffusers text-to-image pipelines."""
# NOTE(review): `spaces` is kept as the first import, as in the original; ZeroGPU
# Spaces generally want it imported before CUDA work happens — confirm before moving.
import spaces
import os
import json
import time
import copy
import glob
import base64
import traceback
from pathlib import Path

import torch
import jwt
from diffusers import (
    AutoPipelineForText2Image,
    StableDiffusionPipeline,
    DiffusionPipeline,
    StableDiffusionXLPipeline,
    AutoencoderKL,
    AutoencoderTiny,
    UNet2DConditionModel,
)
from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from huggingface_hub import hf_hub_download, snapshot_download
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import utils
# from onediffx import compile_pipe, save_pipe, load_pipe

HF_TOKEN = os.getenv('HF_TOKEN')
VAR_PUBLIC_KEY = os.getenv('PUBLIC_KEY')
DATASET_ID = 'nsfwalex/checkpoint_n_lora'


class AuthHelper:
    """Verify client tokens (signed JWTs) against the PEM key in PUBLIC_KEY."""

    def load_public_key_from_file(self):
        """Load the PEM public key held in the PUBLIC_KEY environment variable.

        Despite the method name, nothing is read from disk: the key text comes
        from the module-level ``VAR_PUBLIC_KEY`` (``os.getenv('PUBLIC_KEY')``).

        :return: The deserialized public key object.
        :raises ValueError: If PUBLIC_KEY is unset or empty. Malformed PEM data
            makes ``cryptography`` raise as well.
        """
        if not VAR_PUBLIC_KEY:
            # Fail with a clear message instead of AttributeError on None.encode().
            raise ValueError("PUBLIC_KEY environment variable is not set")
        public_key_bytes = VAR_PUBLIC_KEY.encode('utf-8')
        return serialization.load_pem_public_key(
            public_key_bytes,
            backend=default_backend()
        )

    def __init__(self):
        self.public_key = self.load_public_key_from_file()

    # check authkey
    # 1. decode with public key
    # 2. check timestamp (delegated to jwt.decode's standard claim validation)
    # 3. check current host, referer, ip: they must equal the values in the JWT
    def decode_jwt(self, token, algorithms=None):
        """
        Decode and verify a JWT using this helper's public key.

        :param token: The JWT string to decode.
        :param algorithms: Acceptable signing algorithms; defaults to ["RS256"].
        :return: The decoded JWT payload if verification is successful.
        :raises: Whatever ``jwt.decode`` raises when verification fails
            (the error is printed, then re-raised).
        """
        if algorithms is None:
            # Built per call, so no shared mutable default list.
            algorithms = ["RS256"]
        try:
            return jwt.decode(
                token,
                self.public_key,
                algorithms=algorithms,
                options={"verify_signature": True}  # Explicitly enable signature verification
            )
        except Exception as e:
            print("Invalid token:", e)
            raise

    def check_auth(self, session, token):
        """Check that ``token`` is a valid JWT bound to this client session.

        The token must verify against the public key. It must also carry
        non-empty ``ip``, ``host`` and ``referer`` claims equal to the
        session's ``client_ip``, ``host`` and ``refer`` values.

        :param session: Mapping with ``client_ip``, ``host``, ``refer`` and an
            optional ``params`` mapping.
        :param token: The JWT string supplied by the client.
        :return: True when the check passes.
        :raises Exception: "invalid token" if a claim is missing, or "wrong token"
            on a mismatch. JWT verification errors from decode_jwt propagate.
        """
        params = session.get("params") or {}
        # SECURITY(review): hard-coded bypass passkey. Anyone who knows this value
        # skips JWT verification entirely; consider moving it to a secret or removing it.
        if params.get("_skip_token_passkey", "") == "nsfwaisio_125687":
            return True
        sip = session.get("client_ip", "")
        shost = session.get("host", "")
        sreferer = session.get("refer")
        print(sip, shost, sreferer)
        jwt_data = self.decode_jwt(token)
        tip = jwt_data.get("ip", "")
        thost = jwt_data.get("host", "")
        treferer = jwt_data.get("referer", "")
        print(sip, tip, shost, thost, sreferer, treferer)
        if not tip or not thost or not treferer:
            raise Exception("invalid token")
        if sip == tip and shost == thost and sreferer == treferer:
            return True
        raise Exception("wrong token")


class InferenceManager:
    """Own one base diffusion pipeline (from a JSON config) plus its LoRAs."""

    def __init__(self, config_path="config.json"):
        """Load the config, the LoRA index and the base model, then preload LoRAs.

        :param config_path: JSON config. ``model_version`` and ``model_id`` are
            required; ``loras``, ``clip_skip`` and ``sampler`` are optional.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        self.cfg = cfg
        lora_options_path = cfg.get("loras", "")
        self.model_version = cfg["model_version"]
        self.lora_load_options = self.load_json(lora_options_path)  # Per-LoRA load options
        self.lora_models = self.load_index_file("index.json")  # LoRA index from the HF dataset
        self.preloaded_loras = []  # [{"name": ..., "weight": ...}] for LoRAs loaded at startup
        self.base_model_pipeline = self.load_base_model()
        self.preload_loras()

    def load_json(self, filepath):
        """Load a JSON file into a dict; return {} if the path does not exist."""
        if os.path.exists(filepath):
            with open(filepath, "r", encoding="utf-8") as f:
                return json.load(f)
        return {}

    def load_index_file(self, index_file):
        """Download ``index_file`` from the HF dataset and return its parsed JSON.

        Returns {} when the download fails.
        """
        index_path = download_from_hf(index_file)
        if index_path:
            with open(index_path, "r", encoding="utf-8") as f:
                return json.load(f)
        return {}

    @spaces.GPU(duration=40)
    def compile_onediff(self):
        """Compile the base pipeline with OneDiff, run one warm-up image, and cache it.

        NOTE(review): ``load_pipe``, ``compile_pipe`` and ``save_pipe`` come from the
        commented-out ``onediffx`` import at the top of the file. Unless that import
        is restored, calling this method raises NameError.
        """
        self.base_model_pipeline.to("cuda")
        pipe = self.base_model_pipeline
        # load the compiled pipe
        load_pipe(pipe, dir="cached_pipe")
        print("Start oneflow compiling...")
        start_compile = time.time()
        pipe = compile_pipe(pipe)
        # run once to trigger compilation
        image = pipe(
            prompt="street style, detailed, raw photo, woman, face, shot on CineStill 800T",
            height=512,
            width=512,
            num_inference_steps=10,
            output_type="pil",
        ).images
        image[0].save("test_image.png")
        compile_time = time.time() - start_compile
        # save the compiled pipe
        save_pipe(pipe, dir="cached_pipe")
        self.base_model_pipeline = pipe
        print(f"OneDiff compile in {compile_time}s")

    def load_base_model(self):
        """Download ``cfg["model_id"]`` and build the bfloat16 pipeline for it.

        Version "1.5" uses StableDiffusionPipeline with the checkpoint's own VAE.
        Any other version uses the ``lpw_stable_diffusion_xl`` custom pipeline with
        the tiny ``madebyollin/taesdxl`` VAE, and applies ``clip_skip`` by trimming
        ``text_encoder`` layers.
        """
        start = time.time()
        cfg = self.cfg
        model_version = self.model_version
        ckpt_dir = snapshot_download(repo_id=cfg["model_id"], local_files_only=False)
        if model_version == "1.5":
            vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
            pipe = StableDiffusionPipeline.from_pretrained(ckpt_dir, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True)
        else:
            vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
            print(ckpt_dir)
            pipe = DiffusionPipeline.from_pretrained(
                ckpt_dir,
                vae=vae,
                torch_dtype=torch.bfloat16,
                use_safetensors=True,
                custom_pipeline="lpw_stable_diffusion_xl",
            )
            # Adjust clip skip for XL (assumed not relevant for SD 1.5).
            # NOTE(review): the original's indentation was lost; this placement
            # follows that comment. clip_skip=1 (the default) is a no-op.
            clip_skip = cfg.get("clip_skip", 1)
            pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)
        load_time = round(time.time() - start, 2)
        print(f"Base model loaded in {load_time}s")
        return pipe

    def preload_loras(self):
        """Load every LoRA marked ``preload`` for this model version into the base pipeline.

        Each successfully loaded LoRA is recorded in ``self.preloaded_loras`` with its
        configured weight (default 1.0). Per-LoRA failures are printed and skipped.
        """
        index_loras = self.lora_models.get("lora", [])
        for lora_name, lora_info in self.lora_load_options.items():
            try:
                start = time.time()
                # Filter first so LoRAs that will not be preloaded never touch the index/network.
                if self.model_version not in (lora_info.get('base_model') or []) or not lora_info.get('preload', False):
                    print(f"Skipping {lora_name} as it's not compatible with the current model version.")
                    continue
                lora_index_info = next((l for l in index_loras if l['name'] == lora_name), None)
                if not lora_index_info:
                    raise ValueError(f"LoRA {lora_name} not found in index.json.")
                weight_path = download_from_hf(lora_index_info['path'], local_dir=None)
                if not weight_path:
                    raise ValueError(f"Failed to download LoRA weights for {lora_name}")
                load_time = round(time.time() - start, 2)
                print(f"Downloaded {lora_name} in {load_time}s")
                self.base_model_pipeline.load_lora_weights(
                    weight_path, weight_name=lora_index_info["path"], adapter_name=lora_name
                )
                weight = lora_info.get("weight", 1.0)
                self.preloaded_loras.append({"name": lora_name, "weight": weight})
                load_time = round(time.time() - start, 2)
                print(f"Preloaded LoRA {lora_name} with weight {weight} in {load_time}s.")
            except Exception as e:
                print(f"Lora {lora_name} not loaded, skipping... {e}")

    def build_pipeline_with_lora(self, lora_list, sampler=None, new_pipeline=False):
        """Return a pipeline with preloaded + requested LoRAs active and the sampler set.

        :param lora_list: LoRA names to activate in addition to the preloaded ones.
            Names missing from index.json, or incompatible with this model version,
            are ignored with a message.
        :param sampler: One of "Euler a", "DPM++ SDE Karras" or "DPM2 a". Defaults to
            ``cfg["sampler"]``, or "DPM2 a" if that is unset.
        :param new_pipeline: If True, deep-copy the base pipeline. Otherwise the shared
            base pipeline itself is modified and returned.
        :raises KeyError: For an unknown sampler name.
        """
        start = time.time()
        if new_pipeline:
            temp_pipeline = copy.deepcopy(self.base_model_pipeline)
        else:
            temp_pipeline = self.base_model_pipeline
        copy_time = round(time.time() - start, 2)
        print(f"pipeline copied in {copy_time}s")

        # Collect LoRAs that are not preloaded and must be loaded on demand.
        seen = {l["name"] for l in self.preloaded_loras}
        index_loras = self.lora_models.get("lora", [])
        dynamic_loras = []
        for lora_name in lora_list or []:
            if lora_name in seen:
                continue  # already preloaded or already requested
            seen.add(lora_name)
            lora_info = next((l for l in index_loras if l['name'] == lora_name), None)
            if lora_info and self.model_version in (lora_info.get("attr") or {}).get("base_model", []):
                dynamic_loras.append({
                    "name": lora_name,
                    "filename": lora_info["path"],
                    "scale": 1.0  # Dynamic LoRAs use a default weight of 1.0
                })
            else:
                print(f"Ignoring LoRA {lora_name}: not in index.json or not compatible with model version {self.model_version}.")

        all_loras = [{"name": x["name"], "scale": x["weight"], "preloaded": True} for x in self.preloaded_loras] + dynamic_loras
        set_lora_weights(temp_pipeline, all_loras, False)
        build_time = round(time.time() - start, 2)
        print(f"Pipeline built with LoRAs in {build_time}s.")

        if not sampler:
            sampler = self.cfg.get("sampler", "DPM2 a")
        scheduler_config = temp_pipeline.scheduler.config
        # Factories, so only the selected scheduler is built.
        # NOTE(review): DPMSolverSDEScheduler presumably needs the optional torchsde
        # package; building it lazily avoids that requirement for the other samplers — verify.
        sampler_factories = {
            "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(scheduler_config),
            "DPM++ SDE Karras": lambda: DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=True),
            "DPM2 a": lambda: DPMSolverMultistepScheduler.from_config(scheduler_config),
        }
        factory = sampler_factories.get(sampler)
        if factory is None:
            raise KeyError(f"Unknown sampler {sampler!r}; expected one of {sorted(sampler_factories)}")
        temp_pipeline.scheduler = factory()
        return temp_pipeline

    def release(self, temp_pipeline):
        """Drop this call's reference to ``temp_pipeline`` and empty the CUDA cache.

        NOTE(review): ``del`` only removes the local name. Memory is reclaimed only
        once the caller (and ``self.base_model_pipeline``, if it is the same object)
        no longer reference the pipeline.
        """
        del temp_pipeline
        torch.cuda.empty_cache()
        print("Memory released and cache cleared.")


class ModelManager:
    """Create and hold one InferenceManager per ``*.model.json`` config in a directory."""

    def __init__(self, model_directory):
        """
        Initialize the ModelManager by scanning all `.model.json` files in the given directory.

        :param model_directory: The directory to scan for model config files (e.g., "/path/to/models").
        """
        self.models = {}
        self.model_directory = model_directory
        self.load_models()

    def load_models(self):
        """
        Scan ``self.model_directory`` for `.model.json` files and create an InferenceManager for each.

        The model name is the file name up to its first dot. Models that fail to
        initialize are logged with a traceback and skipped.
        """
        model_files = glob.glob(os.path.join(self.model_directory, "*.model.json"))
        if not model_files:
            print(f"No model configuration files found in {self.model_directory}")
            return

        for file_path in model_files:
            model_name = self.get_model_name_from_url(file_path).split(".")[0]
            print(f"Initializing model: {model_name} from {file_path}")
            try:
                self.models[model_name] = InferenceManager(config_path=file_path)
            except Exception as e:
                print(traceback.format_exc())
                print(f"Failed to initialize model {model_name} from {file_path}: {e}")

    def get_model_name_from_url(self, url):
        """
        Return the base name of ``url`` with its last extension removed.

        :param url: The file path of the configuration file.
        :return: e.g. "foo.model" for "/dir/foo.model.json".
        """
        filename = os.path.basename(url)
        model_name, _ = os.path.splitext(filename)
        return model_name

    def get_model_pipeline(self, model_id, lora_list, sampler=None, new_pipeline=False):
        """
        Build the pipeline with specific LoRAs for a model.

        :param model_id: The model name (derived from the config file name).
        :param lora_list: List of LoRA names to apply to the model pipeline.
        :param sampler: The sampler to use for the pipeline.
        :param new_pipeline: Whether to deep-copy the base pipeline instead of reusing it.
        :return: The built pipeline, or None if the model is unknown or building fails.
        """
        model = self.models.get(model_id)
        if not model:
            print(f"Model {model_id} not found.")
            return None

        try:
            print(f"Building pipeline with LoRAs for model {model_id}...")
            return model.build_pipeline_with_lora(lora_list, sampler, new_pipeline)
        except Exception as e:
            print(traceback.format_exc())
            print(f"Failed to build pipeline for model {model_id}: {e}")
            return None

    def release_model(self, model_id):
        """
        Call InferenceManager.release on a model's base pipeline.

        NOTE(review): the InferenceManager still holds ``base_model_pipeline``, so
        this mainly empties the CUDA cache; the pipeline itself stays alive.

        :param model_id: The model name (derived from the config file name).
        """
        model = self.models.get(model_id)
        if not model:
            print(f"Model {model_id} not found.")
            return

        try:
            print(f"Releasing model {model_id}...")
            model.release(model.base_model_pipeline)
        except Exception as e:
            print(f"Failed to release model {model_id}: {e}")


def download_from_hf(filename, local_dir=None):
    """Download ``filename`` from the DATASET_ID dataset repo (main branch).

    :return: The local file path, or None if the download fails (the error is printed).
    """
    try:
        return hf_hub_download(
            filename=filename,
            repo_id=DATASET_ID,
            repo_type="dataset",
            revision="main",
            local_dir=local_dir,
            local_files_only=False,  # Network access allowed; hf_hub_download reuses its cache when possible
        )
    except Exception as e:
        print(f"Failed to load {filename} from Hugging Face: {str(e)}")
        return None


def _loaded_adapter_names(pipe):
    """Return the LoRA adapter names already loaded on ``pipe`` (empty set if unknown)."""
    get_list_adapters = getattr(pipe, "get_list_adapters", None)
    if get_list_adapters is None:
        return set()
    try:
        per_component = get_list_adapters()
    except Exception:
        # Best effort: if the pipeline cannot report its adapters, fall back to always loading.
        return set()
    return {name for names in per_component.values() for name in names}


def set_lora_weights(pipe, lorajson: list[dict], fuse=False):
    """Load (if needed) and activate the LoRAs described by ``lorajson`` on ``pipe``.

    :param pipe: A diffusers pipeline supporting LoRA adapters.
    :param lorajson: Dicts with ``name`` and ``scale``. Entries not marked
        ``preloaded`` must also give ``filename`` (a path in the HF dataset).
        Entries with an empty name or the name "None" are skipped.
    :param fuse: If True, also fuse the active adapters into the weights.
    :raises Exception: "External LoRA Error: ..." wrapping any failure.
    """
    try:
        if not lorajson or not isinstance(lorajson, list):
            return
        # Adapters already on the pipeline (e.g. from an earlier request on a shared
        # pipeline) must not be loaded again: reusing an adapter name raises.
        already_loaded = _loaded_adapter_names(pipe)
        a_list = []
        w_list = []
        for d in lorajson:
            if not d or not isinstance(d, dict) or not d["name"] or d["name"] == "None":
                continue
            k = d["name"]
            if not d.get("preloaded", False) and k not in already_loaded:
                start = time.time()
                weight_path = download_from_hf(d['filename'], local_dir=None)
                if not weight_path:
                    raise ValueError(f"Failed to download LoRA weights for {k}")
                pipe.load_lora_weights(weight_path, weight_name=d['filename'], adapter_name=k)
                already_loaded.add(k)
                load_time = round(time.time() - start, 2)
                print(f"LoRA {k} loaded in {load_time}s.")
            a_list.append(k)
            w_list.append(d["scale"])
        if not a_list:
            # NOTE(review): adapters activated by an earlier call stay active here,
            # because set_adapters is not called with an empty list.
            return
        start = time.time()
        pipe.set_adapters(a_list, adapter_weights=w_list)
        if fuse:
            pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
            fuse_time = round(time.time() - start, 2)
            print(f"LoRAs fused in {fuse_time}s.")
        else:
            set_time = round(time.time() - start, 2)
            print(f"LoRA adapters set in {set_time}s.")
    except Exception as e:
        print(f"External LoRA Error: {e}")
        raise Exception(f"External LoRA Error: {e}") from e