add contronet canny

Files changed:
- app-controlnet.py  +296 -0
- canny_gpu.py  +44 -0
- controlnet/index.html  +414 -0
- controlnet/tailwind.config.js  +0 -0
- latent_consistency_controlnet.py  +1094 -0
- requirements.txt  +2 -1
    	
app-controlnet.py  ADDED
@@ -0,0 +1,296 @@
import asyncio
import json
import logging
import traceback
from pydantic import BaseModel

from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles

from diffusers import AutoencoderTiny, ControlNetModel
from latent_consistency_controlnet import LatentConsistencyModelPipeline_controlnet
from compel import Compel
import torch

from canny_gpu import SobelOperator
# from controlnet_aux import OpenposeDetector
# import cv2

try:
    import intel_extension_for_pytorch as ipex
except:
    pass
from PIL import Image
import numpy as np
import gradio as gr
import io
import uuid
import os
import time
import psutil


MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
TIMEOUT = float(os.environ.get("TIMEOUT", 0))
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
WIDTH = 512
HEIGHT = 512
# disable tiny autoencoder for better quality speed tradeoff
USE_TINY_AUTOENCODER = True

# check if MPS is available OSX only M1/M2/M3 chips
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)

# change to torch.float16 to save GPU memory
torch_dtype = torch.float16

print(f"TIMEOUT: {TIMEOUT}")
print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"MAX_QUEUE_SIZE: {MAX_QUEUE_SIZE}")
print(f"device: {device}")

if mps_available:
    device = torch.device("mps")
    device = "cpu"
    torch_dtype = torch.float32

controlnet_canny = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch_dtype
).to(device)

canny_torch = SobelOperator(device=device)
# controlnet_pose = ControlNetModel.from_pretrained(
#     "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch_dtype
# ).to(device)
# controlnet_depth = ControlNetModel.from_pretrained(
#     "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch_dtype
# ).to(device)


# pose_processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

if SAFETY_CHECKER == "True":
    pipe = LatentConsistencyModelPipeline_controlnet.from_pretrained(
        "SimianLuo/LCM_Dreamshaper_v7",
        controlnet=controlnet_canny,
        scheduler=None,
    )
else:
    pipe = LatentConsistencyModelPipeline_controlnet.from_pretrained(
        "SimianLuo/LCM_Dreamshaper_v7",
        safety_checker=None,
        controlnet=controlnet_canny,
        scheduler=None,
    )

if USE_TINY_AUTOENCODER:
    pipe.vae = AutoencoderTiny.from_pretrained(
        "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
    )
pipe.set_progress_bar_config(disable=True)
pipe.to(device=device, dtype=torch_dtype).to(device)
pipe.unet.to(memory_format=torch.channels_last)

if psutil.virtual_memory().total < 64 * 1024**3:
    pipe.enable_attention_slicing()

# if not mps_available and not xpu_available:
#     pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
#     pipe(prompt="warmup", image=[Image.new("RGB", (512, 512))])

compel_proc = Compel(
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder,
    truncate_long_prompts=False,
)
user_queue_map = {}


class InputParams(BaseModel):
    seed: int = 2159232
    prompt: str
    guidance_scale: float = 8.0
    strength: float = 0.5
    steps: int = 4
    lcm_steps: int = 50
    width: int = WIDTH
    height: int = HEIGHT
    controlnet_scale: float = 0.8
    controlnet_start: float = 0.0
    controlnet_end: float = 1.0
    canny_low_threshold: float = 0.31
    canny_high_threshold: float = 0.78

def predict(
    input_image: Image.Image, params: InputParams, prompt_embeds: torch.Tensor = None
):
    generator = torch.manual_seed(params.seed)

    control_image = canny_torch(input_image, params.canny_low_threshold, params.canny_high_threshold)
    print(params.canny_low_threshold, params.canny_high_threshold)
    results = pipe(
        control_image=control_image,
        prompt_embeds=prompt_embeds,
        generator=generator,
        image=input_image,
        strength=params.strength,
        num_inference_steps=params.steps,
        guidance_scale=params.guidance_scale,
        width=params.width,
        height=params.height,
        lcm_origin_steps=params.lcm_steps,
        output_type="pil",
        controlnet_conditioning_scale=params.controlnet_scale,
        control_guidance_start=params.controlnet_start,
        control_guidance_end=params.controlnet_end,
    )
    nsfw_content_detected = (
        results.nsfw_content_detected[0]
        if "nsfw_content_detected" in results
        else False
    )
    if nsfw_content_detected:
        return None
    return results.images[0]


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    if MAX_QUEUE_SIZE > 0 and len(user_queue_map) >= MAX_QUEUE_SIZE:
        print("Server is full")
        await websocket.send_json({"status": "error", "message": "Server is full"})
        await websocket.close()
        return

    try:
        uid = str(uuid.uuid4())
        print(f"New user connected: {uid}")
        await websocket.send_json(
            {"status": "success", "message": "Connected", "userId": uid}
        )
        user_queue_map[uid] = {"queue": asyncio.Queue()}
        await websocket.send_json(
            {"status": "start", "message": "Start Streaming", "userId": uid}
        )
        await handle_websocket_data(websocket, uid)
    except WebSocketDisconnect as e:
        logging.error(f"WebSocket Error: {e}, {uid}")
        traceback.print_exc()
    finally:
        print(f"User disconnected: {uid}")
        queue_value = user_queue_map.pop(uid, None)
        queue = queue_value.get("queue", None)
        if queue:
            while not queue.empty():
                try:
                    queue.get_nowait()
                except asyncio.QueueEmpty:
                    continue


@app.get("/queue_size")
async def get_queue_size():
    queue_size = len(user_queue_map)
    return JSONResponse({"queue_size": queue_size})


@app.get("/stream/{user_id}")
async def stream(user_id: uuid.UUID):
    uid = str(user_id)
    try:
        user_queue = user_queue_map[uid]
        queue = user_queue["queue"]

        async def generate():
            last_prompt: str = None
            prompt_embeds: torch.Tensor = None
            while True:
                data = await queue.get()
                input_image = data["image"]
                params = data["params"]
                if input_image is None:
                    continue
                # avoid recalculate prompt embeds
                if last_prompt != params.prompt:
                    print("new prompt")
                    prompt_embeds = compel_proc(params.prompt)
                    last_prompt = params.prompt

                image = predict(
                    input_image,
                    params,
                    prompt_embeds,
                )
                if image is None:
                    continue
                frame_data = io.BytesIO()
                image.save(frame_data, format="JPEG")
                frame_data = frame_data.getvalue()
                if frame_data is not None and len(frame_data) > 0:
                    yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"

                await asyncio.sleep(1.0 / 120.0)

        return StreamingResponse(
            generate(), media_type="multipart/x-mixed-replace;boundary=frame"
        )
    except Exception as e:
        logging.error(f"Streaming Error: {e}, {user_queue_map}")
        traceback.print_exc()
        return HTTPException(status_code=404, detail="User not found")


async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
    uid = str(user_id)
    user_queue = user_queue_map[uid]
    queue = user_queue["queue"]
    if not queue:
        return HTTPException(status_code=404, detail="User not found")
    last_time = time.time()
    try:
        while True:
            data = await websocket.receive_bytes()
            params = await websocket.receive_json()
            params = InputParams(**params)
            pil_image = Image.open(io.BytesIO(data))

            while not queue.empty():
                try:
                    queue.get_nowait()
                except asyncio.QueueEmpty:
                    continue
            await queue.put({"image": pil_image, "params": params})
            if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
                await websocket.send_json(
                    {
                        "status": "timeout",
                        "message": "Your session has ended",
                        "userId": uid,
                    }
                )
                await websocket.close()
                return

    except Exception as e:
        logging.error(f"Error: {e}")
        traceback.print_exc()


app.mount("/", StaticFiles(directory="controlnet", html=True), name="public")
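For reference, handle_websocket_data expects each client to send a binary JPEG frame followed immediately by a JSON payload matching InputParams; the generated frames come back separately through the GET /stream/{user_id} MJPEG endpoint. A minimal client sketch, assuming the third-party `websockets` package (not part of this commit) and a server listening on localhost:7860 (an assumed port):

# Hypothetical client sketch, not part of this commit: exercises the /ws protocol above.
import asyncio
import io
import json

import websockets  # assumed extra dependency: pip install websockets
from PIL import Image

async def send_one_frame(url="ws://localhost:7860/ws"):  # port is an assumption
    async with websockets.connect(url) as ws:
        print(json.loads(await ws.recv()))  # {"status": "success", ...}
        print(json.loads(await ws.recv()))  # {"status": "start", "userId": ...}
        buf = io.BytesIO()
        Image.new("RGB", (512, 512), "white").save(buf, format="JPEG")
        await ws.send(buf.getvalue())        # binary image frame first
        await ws.send(json.dumps({           # then the InputParams JSON
            "prompt": "a portrait, photorealistic",
            "seed": 2159232,
            "guidance_scale": 8.0,
            "strength": 0.5,
            "steps": 4,
            "lcm_steps": 50,
            "width": 512,
            "height": 512,
            "controlnet_scale": 0.8,
            "controlnet_start": 0.0,
            "controlnet_end": 1.0,
            "canny_low_threshold": 0.31,
            "canny_high_threshold": 0.78,
        }))

asyncio.run(send_one_frame())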
    	
canny_gpu.py  ADDED
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
from torchvision.transforms import ToTensor, ToPILImage
from PIL import Image

class SobelOperator(nn.Module):
    def __init__(self, device="cuda"):
        super(SobelOperator, self).__init__()
        self.device = device
        self.edge_conv_x = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
            self.device
        )
        self.edge_conv_y = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
            self.device
        )

        sobel_kernel_x = torch.tensor(
            [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=self.device
        )
        sobel_kernel_y = torch.tensor(
            [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]], device=self.device
        )

        self.edge_conv_x.weight = nn.Parameter(sobel_kernel_x.view((1, 1, 3, 3)))
        self.edge_conv_y.weight = nn.Parameter(sobel_kernel_y.view((1, 1, 3, 3)))

    @torch.no_grad()
    def forward(self, image: Image.Image, low_threshold: float, high_threshold: float):
        # Convert PIL image to PyTorch tensor
        image_gray = image.convert("L")
        image_tensor = ToTensor()(image_gray).unsqueeze(0).to(self.device)

        # Compute gradients
        edge_x = self.edge_conv_x(image_tensor)
        edge_y = self.edge_conv_y(image_tensor)
        edge = torch.sqrt(edge_x**2 + edge_y**2)

        # Apply thresholding
        edge = edge / edge.max()  # Normalize to 0-1
        edge[edge >= high_threshold] = 1.0
        edge[edge <= low_threshold] = 0.0

        # Convert the result back to a PIL image
        return ToPILImage()(edge.squeeze(0).cpu())
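SobelOperator is the Sobel-plus-thresholding edge detector that app-controlnet.py wires in as canny_torch, used in place of an OpenCV Canny pass. A standalone usage sketch (the input path and output filename below are placeholders, not from this commit):

# Hypothetical standalone usage of canny_gpu.SobelOperator.
import torch
from PIL import Image
from canny_gpu import SobelOperator

device = "cuda" if torch.cuda.is_available() else "cpu"
canny = SobelOperator(device=device)
# Thresholds apply to the normalized gradient magnitude; 0.31 / 0.78 are the InputParams defaults.
edges = canny(Image.open("input.jpg").convert("RGB").resize((512, 512)), 0.31, 0.78)
edges.save("edges.png")  # single-channel edge map used as the ControlNet condition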
    	
controlnet/index.html  ADDED
@@ -0,0 +1,414 @@
<!doctype html>
<html>

<head>
    <meta charset="UTF-8">
    <title>Real-Time Latent Consistency Model ControlNet</title>
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <script
        src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/piexifjs@1.0.6/piexif.min.js"></script>
    <script src="https://cdn.tailwindcss.com"></script>
    <style type="text/tailwindcss">
        .button {
          @apply bg-gray-700 hover:bg-gray-800 text-white font-normal p-2 rounded disabled:bg-gray-300 dark:disabled:bg-gray-700 disabled:cursor-not-allowed dark:disabled:text-black
        }
    </style>
    <script type="module">
        // you can change the size of the input image to 768x768 if you have a powerful GPU
        const getValue = (id) => document.querySelector(`${id}`).value;
        const startBtn = document.querySelector("#start");
        const stopBtn = document.querySelector("#stop");
        const videoEl = document.querySelector("#webcam");
        const imageEl = document.querySelector("#player");
        const queueSizeEl = document.querySelector("#queue_size");
        const errorEl = document.querySelector("#error");
        const snapBtn = document.querySelector("#snap");
        const webcamsEl = document.querySelector("#webcams");

        function LCMLive(webcamVideo, liveImage) {
            let websocket;

            async function start() {
                return new Promise((resolve, reject) => {
                    const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
                        }:${window.location.host}/ws`;

                    const socket = new WebSocket(websocketURL);
                    socket.onopen = () => {
                        console.log("Connected to websocket");
                    };
                    socket.onclose = () => {
                        console.log("Disconnected from websocket");
                        stop();
                        resolve({ "status": "disconnected" });
                    };
                    socket.onerror = (err) => {
                        console.error(err);
                        reject(err);
                    };
                    socket.onmessage = (event) => {
                        const data = JSON.parse(event.data);
                        switch (data.status) {
                            case "success":
                                break;
                            case "start":
                                const userId = data.userId;
                                initVideoStream(userId);
                                break;
                            case "timeout":
                                stop();
                                resolve({ "status": "timeout" });
                            case "error":
                                stop();
                                reject(data.message);

                        }
                    };
                    websocket = socket;
                })
            }
            function switchCamera() {
                const constraints = {
                    audio: false,
                    video: { width: 1024, height: 768, deviceId: mediaDevices[webcamsEl.value].deviceId }
                };
                navigator.mediaDevices
                    .getUserMedia(constraints)
                    .then((mediaStream) => {
                        webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
                        webcamVideo.srcObject = mediaStream;
                        webcamVideo.onloadedmetadata = () => {
                            webcamVideo.play();
                            webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
                        };
                    })
                    .catch((err) => {
                        console.error(`${err.name}: ${err.message}`);
                    });
            }

            async function videoTimeUpdateHandler() {
                const dimension = getValue("input[name=dimension]:checked");
                const [WIDTH, HEIGHT] = JSON.parse(dimension);

                const canvas = new OffscreenCanvas(WIDTH, HEIGHT);
                const videoW = webcamVideo.videoWidth;
                const videoH = webcamVideo.videoHeight;
                const aspectRatio = WIDTH / HEIGHT;

                const ctx = canvas.getContext("2d");
                ctx.drawImage(webcamVideo, videoW / 2 - videoH * aspectRatio / 2, 0, videoH * aspectRatio, videoH, 0, 0, WIDTH, HEIGHT)
                const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
                websocket.send(blob);
                websocket.send(JSON.stringify({
                    "seed": getValue("#seed"),
                    "prompt": getValue("#prompt"),
                    "guidance_scale": getValue("#guidance-scale"),
                    "strength": getValue("#strength"),
                    "steps": getValue("#steps"),
                    "lcm_steps": getValue("#lcm_steps"),
                    "width": WIDTH,
                    "height": HEIGHT,
                    "controlnet_scale": getValue("#controlnet_scale"),
                    "controlnet_start": getValue("#controlnet_start"),
                    "controlnet_end": getValue("#controlnet_end"),
                    "canny_low_threshold": getValue("#canny_low_threshold"),
                    "canny_high_threshold": getValue("#canny_high_threshold"),
                }));
            }
            let mediaDevices = [];
            async function initVideoStream(userId) {
                liveImage.src = `/stream/${userId}`;
                await navigator.mediaDevices.enumerateDevices()
                    .then(devices => {
                        const cameras = devices.filter(device => device.kind === 'videoinput');
                        mediaDevices = cameras;
                        webcamsEl.innerHTML = "";
                        cameras.forEach((camera, index) => {
                            const option = document.createElement("option");
                            option.value = index;
                            option.innerText = camera.label;
                            webcamsEl.appendChild(option);
                            option.selected = index === 0;
                        });
                        webcamsEl.addEventListener("change", switchCamera);
                    })
                    .catch(err => {
                        console.error(err);
         
     | 
| 139 | 
         
            +
                                });
         
     | 
| 140 | 
         
            +
                            const constraints = {
         
     | 
| 141 | 
         
            +
                                audio: false,
         
     | 
| 142 | 
         
            +
                                video: { width: 1024, height: 768, deviceId: mediaDevices[0].deviceId }
         
     | 
| 143 | 
         
            +
                            };
         
     | 
| 144 | 
         
            +
                            navigator.mediaDevices
         
     | 
| 145 | 
         
            +
                                .getUserMedia(constraints)
         
     | 
| 146 | 
         
            +
                                .then((mediaStream) => {
         
     | 
| 147 | 
         
            +
                                    webcamVideo.srcObject = mediaStream;
         
     | 
| 148 | 
         
            +
                                    webcamVideo.onloadedmetadata = () => {
         
     | 
| 149 | 
         
            +
                                        webcamVideo.play();
         
     | 
| 150 | 
         
            +
                                        webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
         
     | 
| 151 | 
         
            +
                                    };
         
     | 
| 152 | 
         
            +
                                })
         
     | 
| 153 | 
         
            +
                                .catch((err) => {
         
     | 
| 154 | 
         
            +
                                    console.error(`${err.name}: ${err.message}`);
         
     | 
| 155 | 
         
            +
                                });
         
     | 
| 156 | 
         
            +
                        }
         
     | 
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
             
     | 
| 159 | 
         
            +
                        async function stop() {
         
     | 
| 160 | 
         
            +
                            websocket.close();
         
     | 
| 161 | 
         
            +
                            navigator.mediaDevices.getUserMedia({ video: true }).then((mediaStream) => {
         
     | 
| 162 | 
         
            +
                                mediaStream.getTracks().forEach((track) => track.stop());
         
     | 
| 163 | 
         
            +
                            });
         
     | 
| 164 | 
         
            +
                            webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
         
     | 
| 165 | 
         
            +
                            webcamsEl.removeEventListener("change", switchCamera);
         
     | 
| 166 | 
         
            +
                            webcamVideo.srcObject = null;
         
     | 
| 167 | 
         
            +
                        }
         
     | 
| 168 | 
         
            +
                        return {
         
     | 
| 169 | 
         
            +
                            start,
         
     | 
| 170 | 
         
            +
                            stop
         
     | 
| 171 | 
         
            +
                        }
         
     | 
| 172 | 
         
            +
                    }
         
     | 
        function toggleMessage(type) {
            errorEl.hidden = false;
            errorEl.scrollIntoView();
            switch (type) {
                case "error":
                    errorEl.innerText = "Too many users are using the same GPU, please try again later.";
                    // classList.toggle(a, b) treats the second argument as a force flag,
                    // so set both colour classes explicitly instead
                    errorEl.classList.remove("bg-green-300", "text-green-900");
                    errorEl.classList.add("bg-red-300", "text-red-900");
                    break;
                case "success":
                    errorEl.innerText = "Your session has ended, please start a new one.";
                    errorEl.classList.remove("bg-red-300", "text-red-900");
                    errorEl.classList.add("bg-green-300", "text-green-900");
                    break;
            }
            setTimeout(() => {
                errorEl.hidden = true;
            }, 2000);
        }
        function snapImage() {
            try {
                const zeroth = {};
                const exif = {};
                const gps = {};
                zeroth[piexif.ImageIFD.Make] = "LCM Image-to-Image ControlNet";
                zeroth[piexif.ImageIFD.ImageDescription] = `prompt: ${getValue("#prompt")} | seed: ${getValue("#seed")} | guidance_scale: ${getValue("#guidance-scale")} | strength: ${getValue("#strength")} | controlnet_start: ${getValue("#controlnet_start")} | controlnet_end: ${getValue("#controlnet_end")} | lcm_steps: ${getValue("#lcm_steps")} | steps: ${getValue("#steps")}`;
                zeroth[piexif.ImageIFD.Software] = "https://github.com/radames/Real-Time-Latent-Consistency-Model";
                exif[piexif.ExifIFD.DateTimeOriginal] = new Date().toISOString();

                const exifObj = { "0th": zeroth, "Exif": exif, "GPS": gps };
                const exifBytes = piexif.dump(exifObj);

                const canvas = document.createElement("canvas");
                canvas.width = imageEl.naturalWidth;
                canvas.height = imageEl.naturalHeight;
                const ctx = canvas.getContext("2d");
                ctx.drawImage(imageEl, 0, 0);
                const dataURL = canvas.toDataURL("image/jpeg");
                const withExif = piexif.insert(exifBytes, dataURL);

                const a = document.createElement("a");
                a.href = withExif;
                // the canvas is exported as JPEG, so use a matching extension
                a.download = `lcm_controlnet_img2img_${Date.now()}.jpg`;
                a.click();
            } catch (err) {
                console.log(err);
            }
        }

        const lcmLive = LCMLive(videoEl, imageEl);
        startBtn.addEventListener("click", async () => {
            try {
                startBtn.disabled = true;
                snapBtn.disabled = false;
                const res = await lcmLive.start();
                startBtn.disabled = false;
                if (res.status === "timeout")
                    toggleMessage("success");
            } catch (err) {
                console.log(err);
                toggleMessage("error");
                startBtn.disabled = false;
            }
        });
        stopBtn.addEventListener("click", () => {
            lcmLive.stop();
        });
        window.addEventListener("beforeunload", () => {
            lcmLive.stop();
        });
        snapBtn.addEventListener("click", snapImage);
        setInterval(() =>
            fetch("/queue_size")
                .then((res) => res.json())
                .then((data) => {
                    queueSizeEl.innerText = data.queue_size;
                })
                .catch((err) => {
                    console.log(err);
                })
            , 5000);
    </script>
</head>

<body class="text-black dark:bg-gray-900 dark:text-white">
    <div class="fixed right-2 top-2 p-4 font-bold text-sm rounded-lg max-w-xs text-center" id="error">
    </div>
    <main class="container mx-auto px-4 py-4 max-w-4xl flex flex-col gap-4">
        <article class="text-center max-w-xl mx-auto">
            <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
            <h2 class="text-2xl font-bold mb-4">ControlNet</h2>
            <p class="text-sm">
                This demo showcases the
                <a href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7" target="_blank"
                    class="text-blue-500 underline hover:no-underline">LCM</a> Image-to-Image pipeline
                using
                <a href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
                    target="_blank" class="text-blue-500 underline hover:no-underline">Diffusers</a> with an MJPEG
                stream server.
            </p>
            <p class="text-sm">
                There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU, affecting
                real-time performance. Maximum queue size is 4. <a
                    href="https://huggingface.co/spaces/radames/Real-Time-Latent-Consistency-Model?duplicate=true"
                    target="_blank" class="text-blue-500 underline hover:no-underline">Duplicate</a> and run it on your
                own GPU.
            </p>
        </article>
        <div>
            <h2 class="font-medium">Prompt</h2>
            <p class="text-sm text-gray-500">
                Change the prompt to generate different images, accepts <a
                    href="https://github.com/damian0815/compel/blob/main/doc/syntax.md" target="_blank"
                    class="text-blue-500 underline hover:no-underline">Compel</a> syntax.
            </p>
            <div class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
                <textarea id="prompt" class="font-light w-full px-3 py-2 mx-1 outline-none dark:text-black"
                    title="Prompt, this is an example, feel free to modify"
                    placeholder="Add your prompt here...">Portrait of The Terminator with glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5, cinematic, masterpiece</textarea>
            </div>
        </div>
        <div class="">
            <details>
                <summary class="font-medium cursor-pointer">Advanced Options</summary>
                <div class="grid grid-cols-3 sm:grid-cols-6 items-center gap-3 py-3">
                    <label for="webcams" class="text-sm font-medium">Camera Options: </label>
                    <select id="webcams" class="text-sm border-2 border-gray-500 rounded-md font-light dark:text-black">
                    </select>
                    <div></div>
                    <label class="text-sm font-medium" for="steps">Inference Steps
                    </label>
                    <input type="range" id="steps" name="steps" min="1" max="20" value="4"
                        oninput="this.nextElementSibling.value = Number(this.value)">
                    <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
                        4</output>
                    <!--  -->
                    <label class="text-sm font-medium" for="lcm_steps">LCM Inference Steps
                    </label>
                    <input type="range" id="lcm_steps" name="lcm_steps" min="2" max="60" value="50"
                        oninput="this.nextElementSibling.value = Number(this.value)">
                    <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
                        50</output>
                    <!--  -->
                    <label class="text-sm font-medium" for="guidance-scale">Guidance Scale
                    </label>
                    <input type="range" id="guidance-scale" name="guidance-scale" min="0" max="30" step="0.001"
                        value="8.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
                    <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
                        8.0</output>
                    <!--  -->
                    <label class="text-sm font-medium" for="strength">Strength</label>
                    <input type="range" id="strength" name="strength" min="0.1" max="1" step="0.001" value="0.50"
                        oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
                    <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
                        0.5</output>
                    <!--  -->
                    <label class="text-sm font-medium" for="controlnet_scale">ControlNet Condition Scale</label>
                    <input type="range" id="controlnet_scale" name="controlnet_scale" min="0.0" max="1" step="0.001"
                        value="0.80" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
                    <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
                        0.8</output>
                    <!--  -->
                    <label class="text-sm font-medium" for="controlnet_start">ControlNet Guidance Start</label>
                    <input type="range" id="controlnet_start" name="controlnet_start" min="0.0" max="1.0" step="0.001"
                        value="0.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
                    <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
                        0.0</output>
                    <!--  -->
                    <label class="text-sm font-medium" for="controlnet_end">ControlNet Guidance End</label>
                    <input type="range" id="controlnet_end" name="controlnet_end" min="0.0" max="1.0" step="0.001"
                        value="1.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
                    <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
                        1.0</output>
                    <!--  -->
                    <label class="text-sm font-medium" for="canny_low_threshold">Canny Low Threshold</label>
                    <input type="range" id="canny_low_threshold" name="canny_low_threshold" min="0.0" max="1.0"
                        step="0.001" value="0.2"
                        oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
                    <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
                        0.2</output>
                    <!--  -->
                    <label class="text-sm font-medium" for="canny_high_threshold">Canny High Threshold</label>
                    <input type="range" id="canny_high_threshold" name="canny_high_threshold" min="0.0" max="1.0"
                        step="0.001" value="0.8"
                        oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
                    <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
                        0.8</output>
                    <!--  -->
                    <label class="text-sm font-medium" for="seed">Seed</label>
                    <input type="number" id="seed" name="seed" value="299792458"
                        class="font-light border border-gray-700 text-right rounded-md p-2 dark:text-black">
                    <button
                        onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
                        class="button">
                        Rand
                    </button>
                    <!--  -->
                    <!--  -->
                    <label class="text-sm font-medium" for="dimension">Image Dimensions</label>
                    <div class="col-span-2 flex gap-2">
                        <div class="flex gap-1">
                            <input type="radio" id="dimension512" name="dimension" value="[512,512]" checked
                                class="cursor-pointer">
                            <label for="dimension512" class="text-sm cursor-pointer">512x512</label>
                        </div>
                        <div class="flex gap-1">
                            <input type="radio" id="dimension768" name="dimension" value="[768,768]"
                                class="cursor-pointer">
                            <label for="dimension768" class="text-sm cursor-pointer">768x768</label>
                        </div>
                    </div>
                    <!--  -->
                </div>
            </details>
        </div>
        <div class="flex gap-3">
            <button id="start" class="button">
                Start
            </button>
            <button id="stop" class="button">
                Stop
            </button>
            <button id="snap" disabled class="button ml-auto">
                Snapshot
            </button>
        </div>
        <div class="relative rounded-lg border border-slate-300 overflow-hidden">
            <img id="player" class="w-full aspect-square rounded-lg"
                src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=">
            <div class="absolute top-0 left-0 w-1/4 aspect-square">
                <video id="webcam" class="w-full aspect-square relative z-10 object-cover" playsinline autoplay muted
                    loop></video>
                <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 448" width="100"
                    class="w-full p-4 absolute top-0 opacity-20 z-0">
                    <path fill="currentColor"
                        d="M224 256a128 128 0 1 0 0-256 128 128 0 1 0 0 256zm-45.7 48A178.3 178.3 0 0 0 0 482.3 29.7 29.7 0 0 0 29.7 512h388.6a29.7 29.7 0 0 0 29.7-29.7c0-98.5-79.8-178.3-178.3-178.3h-91.4z" />
                </svg>
            </div>
        </div>
    </main>
</body>

</html>
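For each frame, the client above sends a binary JPEG blob over the WebSocket and immediately follows it with a JSON text message carrying the prompt, seed, strength, ControlNet and Canny parameters, then displays the MJPEG stream from `/stream/{userId}`. A minimal sketch of a server-side loop that consumes that pairing; FastAPI, the `/ws` route, and the handler name are assumptions here, the commit's real handler lives in app-controlnet.py and may differ:

    # Sketch only: assumes a FastAPI WebSocket endpoint; not the app-controlnet.py implementation.
    import json
    from io import BytesIO

    from fastapi import FastAPI, WebSocket
    from PIL import Image

    app = FastAPI()

    @app.websocket("/ws")
    async def ws_handler(websocket: WebSocket):
        await websocket.accept()
        while True:
            # the client sends the JPEG frame first...
            frame_bytes = await websocket.receive_bytes()
            image = Image.open(BytesIO(frame_bytes)).convert("RGB")
            # ...then a JSON payload with prompt, seed, steps, ControlNet and Canny settings
            params = json.loads(await websocket.receive_text())
            prompt = params["prompt"]
            seed = int(params["seed"])
            # hand (image, params) to the diffusion worker that feeds /stream/{user_id}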
    	
controlnet/tailwind.config.js
ADDED
    File without changes

latent_consistency_controlnet.py
ADDED
    @@ -0,0 +1,1094 @@
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    AutoencoderTiny,
    ConfigMixin,
    DiffusionPipeline,
    SchedulerMixin,
    UNet2DConditionModel,
    ControlNetModel,
    logging,
)
from diffusers.configuration_utils import register_to_config
from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from diffusers.utils import BaseOutput

from diffusers.utils.torch_utils import randn_tensor, is_compiled_module
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel


import PIL.Image


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class LatentConsistencyModelPipeline_controlnet(DiffusionPipeline):
    _optional_components = ["scheduler"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        controlnet: Union[
            ControlNetModel,
            List[ControlNetModel],
            Tuple[ControlNetModel],
            MultiControlNetModel,
        ],
        unet: UNet2DConditionModel,
        scheduler: "LCMScheduler",
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        scheduler = (
            scheduler
            if scheduler is not None
            else LCMScheduler_X(
                beta_start=0.00085,
                beta_end=0.0120,
                beta_schedule="scaled_linear",
                prediction_type="epsilon",
            )
        )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_convert_rgb=True,
            do_normalize=False,
        )

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
        """

        if prompt is not None and isinstance(prompt, str):
            pass
        elif prompt is not None and isinstance(prompt, list):
            len(prompt)
        else:
            prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(
                prompt, padding="longest", return_tensors="pt"
            ).input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[
                -1
            ] and not torch.equal(text_input_ids, untruncated_ids):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        # Don't need to get uncond prompt embedding because of LCM Guided Distillation
        return prompt_embeds

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(
                    image, output_type="pil"
                )
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(
                feature_extractor_input, return_tensors="pt"
            ).to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    def prepare_control_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        image = self.control_image_processor.preprocess(
            image, height=height, width=width
        ).to(dtype=dtype)
        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image

    def prepare_latents(
        self,
        image,
        timestep,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        latents=None,
        generator=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )

        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        # batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                if isinstance(self.vae, AutoencoderTiny):
                    init_latents = [
                        self.vae.encode(image[i : i + 1]).latents
                        for i in range(batch_size)
                    ]
                else:
                    init_latents = [
                        self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i])
                        for i in range(batch_size)
                    ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                if isinstance(self.vae, AutoencoderTiny):
                    init_latents = self.vae.encode(image).latents
                else:
                    init_latents = self.vae.encode(image).latent_dist.sample(generator)

            init_latents = self.vae.config.scaling_factor * init_latents

        if (
            batch_size > init_latents.shape[0]
            and batch_size % init_latents.shape[0] == 0
        ):
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat(
                [init_latents] * additional_image_per_prompt, dim=0
            )
        elif (
            batch_size > init_latents.shape[0]
            and batch_size % init_latents.shape[0] != 0
        ):
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

        if latents is None:
            latents = torch.randn(shape, dtype=dtype).to(device)
        else:
            latents = latents.to(device)
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
        Args:
        timesteps: torch.Tensor: generate embedding vectors at these timesteps
        embedding_dim: int: dimension of the embeddings to generate
        dtype: data type of the generated embeddings
        Returns:
        embedding vectors with shape `(len(timesteps), embedding_dim)`
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        control_image: PipelineImageInput = None,
        strength: float = 0.8,
        height: Optional[int] = 768,
        width: Optional[int] = 768,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        latents: Optional[torch.FloatTensor] = None,
        generator: Optional[torch.Generator] = None,
        num_inference_steps: int = 4,
        lcm_origin_steps: int = 50,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
        guess_mode: bool = True,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
    ):
        controlnet = (
            self.controlnet._orig_mod
            if is_compiled_module(self.controlnet)
            else self.controlnet
        )
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        if not isinstance(control_guidance_start, list) and isinstance(
            control_guidance_end, list
        ):
            control_guidance_start = len(control_guidance_end) * [
                control_guidance_start
            ]
        elif not isinstance(control_guidance_end, list) and isinstance(
            control_guidance_start, list
        ):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(
            control_guidance_end, list
        ):
            mult = (
                len(controlnet.nets)
                if isinstance(controlnet, MultiControlNetModel)
                else 1
            )
            control_guidance_start, control_guidance_end = mult * [
                control_guidance_start
            ], mult * [control_guidance_end]
        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # do_classifier_free_guidance = guidance_scale > 0.0  # In LCM Implementation:  cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
        do_classifier_free_guidance = False  # LCM guided distillation does not run an unconditional pass
        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions
        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            prompt_embeds=prompt_embeds,
        )

        # 3.5 encode image
        image = self.image_processor.preprocess(image)

        if isinstance(controlnet, ControlNetModel):
            control_image = self.prepare_control_image(
                image=control_image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                guess_mode=guess_mode,
            )
        elif isinstance(controlnet, MultiControlNetModel):
            control_images = []

            for control_image_ in control_image:
                control_image_ = self.prepare_control_image(
                    image=control_image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )

                control_images.append(control_image_)

            control_image = control_images
        else:
            assert False

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(strength, num_inference_steps, lcm_origin_steps)
        # timesteps = self.scheduler.timesteps
        # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
        timesteps = self.scheduler.timesteps
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        print("timesteps: ", timesteps)

        # 5. Prepare latent variable
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            latents,
        )
        bs = batch_size * num_images_per_prompt

        # 6. Get Guidance Scale Embedding
        w = torch.tensor(guidance_scale).repeat(bs)
        w_embedding = self.get_w_embedding(w, embedding_dim=256).to(
            device=device, dtype=latents.dtype
        )
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(
                keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
            )
        # 7. LCM MultiStep Sampling Loop:
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                ts = torch.full((bs,), t, device=device, dtype=torch.long)
                latents = latents.to(prompt_embeds.dtype)
                if guess_mode:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(
                        control_model_input, ts
                    )
                    controlnet_prompt_embeds = prompt_embeds
                else:
                    control_model_input = latents
                    controlnet_prompt_embeds = prompt_embeds
                if isinstance(controlnet_keep[i], list):
                    cond_scale = [
                        c * s
                        for c, s in zip(
                            controlnet_conditioning_scale, controlnet_keep[i]
                        )
                    ]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    ts,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=control_image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )
                # model prediction (v-prediction, eps, x)
                model_pred = self.unet(
                    latents,
                    ts,
                    timestep_cond=w_embedding,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents, denoised = self.scheduler.step(
                    model_pred, i, t, latents, return_dict=False
                )

                # # call the callback, if provided
                # if i == len(timesteps) - 1:
                progress_bar.update()

        denoised = denoised.to(prompt_embeds.dtype)
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()
        if not output_type == "latent":
            image = self.vae.decode(
                denoised / self.vae.config.scaling_factor, return_dict=False
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype
            )
        else:
            image = denoised
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )

| 615 | 
         
            +
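# ---------------------------------------------------------------------------
# Note (added): the rest of this file bundles a standalone LCM scheduler
# (`LCMSchedulerOutput`, `betas_for_alpha_bar`, `rescale_zero_terminal_snr`,
# `LCMScheduler_X`) so that the pipeline above can call
# `scheduler.step(model_pred, i, t, latents, ...)` with both the step index
# and the timestep value.
# ---------------------------------------------------------------------------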
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class LCMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.
    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        denoised (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `denoised` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    denoised: Optional[torch.FloatTensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].
    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.
    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`
    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


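# Note (added): with the default "cosine" transform, alpha_bar(t) decreases from
# roughly 1 at t = 0 toward 0 at t = 1, so the betas above start near 0 and grow
# toward the end of the schedule, where they are capped at `max_beta` to avoid a
# singular final step.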
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.
    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas


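# Note (added): after this rescaling the final cumulative alpha is exactly zero,
# i.e. the signal-to-noise ratio at the last training timestep is 0, which is
# what allows a model trained this way to produce fully bright or fully dark
# samples, as described in arXiv:2305.08891.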
class LCMScheduler_X(SchedulerMixin, ConfigMixin):
    """
    `LCMScheduler_X` implements the multi-step sampling procedure of latent consistency models, using the
    consistency-model boundary condition instead of a full DDPM/DDIM denoising chain.
    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.
    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
            Diffusion.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(
                beta_start, beta_end, num_train_timesteps, dtype=torch.float32
            )
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(
                    beta_start**0.5,
                    beta_end**0.5,
                    num_train_timesteps,
                    dtype=torch.float32,
                )
                ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(
                f"{beta_schedule} is not implemented for {self.__class__}"
            )

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = (
            torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
        )

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(
            np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
        )

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Optional[int] = None
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.
        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.
        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        return sample

    def _get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = (
            self.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0
            else self.final_alpha_cumprod
        )
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (
            1 - alpha_prod_t / alpha_prod_t_prev
        )

        return variance

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."
        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, height, width = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = (
                sample.float()
            )  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * height * width)

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = (
            torch.clamp(sample, -s, s) / s
        )  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, height, width)
        sample = sample.to(dtype)

        return sample

    def set_timesteps(
        self,
        strength,
        num_inference_steps: int,
        lcm_origin_steps: int,
        device: Union[str, torch.device] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
        Args:
            strength (`float`):
                Img2img-style strength in [0, 1]; only the first `int(lcm_origin_steps * strength)` timesteps of the
                original LCM schedule are kept.
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            lcm_origin_steps (`int`):
                The length of the original LCM training schedule from which the inference timesteps are sub-sampled.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # LCM Timesteps Setting:  # Linear Spacing
        c = self.config.num_train_timesteps // lcm_origin_steps
        lcm_origin_timesteps = (
            np.asarray(list(range(1, int(lcm_origin_steps * strength) + 1))) * c - 1
        )  # LCM Training Steps Schedule
        skipping_step = max(len(lcm_origin_timesteps) // num_inference_steps, 1)
        timesteps = lcm_origin_timesteps[::-skipping_step][
            :num_inference_steps
        ]  # LCM Inference Steps Schedule

        self.timesteps = torch.from_numpy(timesteps.copy()).to(device)

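    # Note (added): worked example for set_timesteps, assuming the default
    # num_train_timesteps=1000 with strength=1.0, num_inference_steps=4 and
    # lcm_origin_steps=50: c = 20, the training schedule is [19, 39, ..., 999],
    # skipping_step = 50 // 4 = 12, and the inference timesteps become
    # [999, 759, 519, 279] (walked backwards from the most-noised step).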
    def get_scalings_for_boundary_condition_discrete(self, t):
        self.sigma_data = 0.5  # Default: 0.5

        # By dividing 0.1: This is almost a delta function at t=0.
        c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
        c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out

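    # Note (added): c_skip and c_out above are the consistency-model boundary
    # scalings: as t -> 0, c_skip -> 1 and c_out -> 0, so the combination
    # `denoised = c_out * pred_x0 + c_skip * sample` used in `step` reduces to the
    # identity at t = 0. `step` first recovers pred_x0 according to
    # `prediction_type`, applies these scalings, and then (for multi-step
    # sampling) re-noises the result to the previous LCM timestep.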
    def step(
        self,
        model_output: torch.FloatTensor,
        timeindex: int,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[LCMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).
        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timeindex (`int`):
                The index of the current timestep in the `self.timesteps` schedule.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                The weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`, defaults to `False`):
                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
                clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
                `use_clipped_model_output` has no effect.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.FloatTensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
        Returns:
            [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # 1. get previous step value
        prev_timeindex = timeindex + 1
        if prev_timeindex < len(self.timesteps):
            prev_timestep = self.timesteps[prev_timeindex]
        else:
            prev_timestep = timestep

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = (
            self.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0
            else self.final_alpha_cumprod
        )

        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)

        # 4. Recover pred_x0 according to the parameterization
        parameterization = self.config.prediction_type

        if parameterization == "epsilon":  # noise-prediction
            pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()

        elif parameterization == "sample":  # x-prediction
            pred_x0 = model_output

        elif parameterization == "v_prediction":  # v-prediction
            pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output

        # 5. Denoise model output using boundary conditions
        denoised = c_out * pred_x0 + c_skip * sample

        # 6. Sample z ~ N(0, I), for multi-step inference
        # Noise is not used for one-step sampling.
        if len(self.timesteps) > 1:
            noise = torch.randn(model_output.shape).to(model_output.device)
            prev_sample = (
                alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
            )
        else:
            prev_sample = denoised

        if not return_dict:
            return (prev_sample, denoised)

        return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(
            device=original_samples.device, dtype=original_samples.dtype
        )
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = (
            sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        )
        return noisy_samples

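    # Note (added): add_noise above is the closed-form forward process
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise; in an
    # img2img-style pipeline it is typically what noises the encoded init image
    # up to the first inference timestep before the LCM denoising loop starts.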
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(
        self,
        sample: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        alphas_cumprod = self.alphas_cumprod.to(
            device=sample.device, dtype=sample.dtype
        )
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        return self.config.num_train_timesteps
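A minimal sketch of driving LCMScheduler_X on its own; the dummy model, the latent shape, and the scaled-linear beta values are illustrative assumptions, not taken from this Space's code, and the import assumes this file is usable as a module:

    import torch
    from latent_consistency_controlnet import LCMScheduler_X  # assumption: file importable as a module

    def dummy_unet(latents, t):
        # stand-in for the real UNet's epsilon prediction
        return torch.zeros_like(latents)

    scheduler = LCMScheduler_X(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
    scheduler.set_timesteps(1.0, 4, 50, device="cpu")  # strength, num_inference_steps, lcm_origin_steps

    latents = torch.randn(1, 4, 64, 64)
    for i, t in enumerate(scheduler.timesteps):
        model_pred = dummy_unet(latents, int(t))
        latents, denoised = scheduler.step(model_pred, i, int(t), latents, return_dict=False)
    # `denoised` holds the final x0 prediction that the real pipeline decodes with the VAE.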
    	
requirements.txt CHANGED

@@ -7,4 +7,5 @@ fastapi==0.104.0
 uvicorn==0.23.2
 Pillow==10.1.0
 accelerate==0.24.0
-compel==2.0.2
+compel==2.0.2
+controlnet-aux==0.0.7