Duplicate from nyanko7/sd-diffusers-webui
Co-authored-by: Nyanko <nyanko7@users.noreply.huggingface.co>
- .gitattributes +34 -0
- Dockerfile +22 -0
- README.md +14 -0
- app.py +878 -0
- modules/lora.py +183 -0
- modules/model.py +897 -0
- modules/prompt_parser.py +391 -0
- modules/safe.py +188 -0
    	
.gitattributes ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
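These patterns route large binary artifacts (checkpoints, safetensors, archives, tensorboard event files) through Git LFS instead of storing them as regular git blobs. As a rough illustration (not part of the commit), gitattributes globs behave much like Python's fnmatch patterns; the filenames below are hypothetical:

    # Illustration only: checking hypothetical filenames against a few of the
    # LFS patterns above. Actual matching is done by git, not by this script.
    from fnmatch import fnmatch

    lfs_patterns = ["*.ckpt", "*.safetensors", "*.pt", "*.bin", "*.zip"]
    for name in ["AbyssOrangeMix2.safetensors", "lora_character.pt", "app.py"]:
        tracked = any(fnmatch(name, pat) for pat in lfs_patterns)
        print(name, "-> git-lfs" if tracked else "-> normal git blob")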
    	
Dockerfile ADDED
@@ -0,0 +1,22 @@
+# Dockerfile Public T4
+
+FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
+ENV DEBIAN_FRONTEND noninteractive
+
+WORKDIR /content
+
+RUN apt-get update -y && apt-get upgrade -y && apt-get install -y libgl1 libglib2.0-0 wget git git-lfs python3-pip python-is-python3 && pip3 install --upgrade pip
+
+RUN pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchsde --extra-index-url https://download.pytorch.org/whl/cu113
+RUN pip install https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.16/xformers-0.0.16+814314d.d20230118-cp310-cp310-linux_x86_64.whl
+RUN pip install --pre triton
+RUN pip install numexpr einops transformers k_diffusion safetensors gradio diffusers==0.12.1
+
+ADD . .
+RUN adduser --disabled-password --gecos '' user
+RUN chown -R user:user /content
+RUN chmod -R 777 /content
+USER user
+
+EXPOSE 7860
+CMD python /content/app.py
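The image pins a CUDA 11.7 / cuDNN 8 base, torch 1.12.1+cu113, a prebuilt xformers wheel, and diffusers 0.12.1, then runs app.py as an unprivileged user on port 7860. A minimal sanity-check sketch (not part of the commit) that could be run inside the built container to confirm the pinned stack resolves and the GPU is visible:

    # Not part of this commit: quick environment check inside the container.
    import torch, diffusers, gradio

    print("torch", torch.__version__, "cuda available:", torch.cuda.is_available())
    print("diffusers", diffusers.__version__)  # Dockerfile pins 0.12.1
    print("gradio", gradio.__version__)
    if torch.cuda.is_available():
        print("device:", torch.cuda.get_device_name(0))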
    	
README.md ADDED
@@ -0,0 +1,14 @@
+---
+title: Sd Diffusers Webui
+emoji: 🐳
+colorFrom: purple
+colorTo: gray
+sdk: docker
+sdk_version: 3.9
+pinned: false
+license: openrail
+app_port: 7860
+duplicated_from: nyanko7/sd-diffusers-webui
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
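The YAML front matter tells Hugging Face Spaces to build this repository with the Docker SDK and to serve the app on app_port 7860; duplicated_from records the source Space. For illustration only (PyYAML is not installed by the Dockerfile above), the metadata can be read back like this:

    # Illustration only: reading the Spaces front matter from README.md.
    import yaml

    with open("README.md") as f:
        text = f.read()
    front_matter = yaml.safe_load(text.split("---")[1])  # body between the first two '---' fences
    print(front_matter["sdk"], front_matter["app_port"], front_matter["duplicated_from"])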
    	
app.py ADDED
@@ -0,0 +1,878 @@
+import random
+import tempfile
+import time
+import gradio as gr
+import numpy as np
+import torch
+import math
+import re
+
+from gradio import inputs
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    UNet2DConditionModel,
+)
+from modules.model import (
+    CrossAttnProcessor,
+    StableDiffusionPipeline,
+)
+from torchvision import transforms
+from transformers import CLIPTokenizer, CLIPTextModel
+from PIL import Image
+from pathlib import Path
+from safetensors.torch import load_file
+import modules.safe as _
+from modules.lora import LoRANetwork
+
+models = [
+    ("AbyssOrangeMix2", "Korakoe/AbyssOrangeMix2-HF", 2),
+    ("Pastal Mix", "andite/pastel-mix", 2),
+    ("Basil Mix", "nuigurumi/basil_mix", 2)
+]
+
+keep_vram = ["Korakoe/AbyssOrangeMix2-HF", "andite/pastel-mix"]
+base_name, base_model, clip_skip = models[0]
+
+samplers_k_diffusion = [
+    ("Euler a", "sample_euler_ancestral", {}),
+    ("Euler", "sample_euler", {}),
+    ("LMS", "sample_lms", {}),
+    ("Heun", "sample_heun", {}),
+    ("DPM2", "sample_dpm_2", {"discard_next_to_last_sigma": True}),
+    ("DPM2 a", "sample_dpm_2_ancestral", {"discard_next_to_last_sigma": True}),
+    ("DPM++ 2S a", "sample_dpmpp_2s_ancestral", {}),
+    ("DPM++ 2M", "sample_dpmpp_2m", {}),
+    ("DPM++ SDE", "sample_dpmpp_sde", {}),
+    ("LMS Karras", "sample_lms", {"scheduler": "karras"}),
+    ("DPM2 Karras", "sample_dpm_2", {"scheduler": "karras", "discard_next_to_last_sigma": True}),
+    ("DPM2 a Karras", "sample_dpm_2_ancestral", {"scheduler": "karras", "discard_next_to_last_sigma": True}),
+    ("DPM++ 2S a Karras", "sample_dpmpp_2s_ancestral", {"scheduler": "karras"}),
+    ("DPM++ 2M Karras", "sample_dpmpp_2m", {"scheduler": "karras"}),
+    ("DPM++ SDE Karras", "sample_dpmpp_sde", {"scheduler": "karras"}),
+]
+
+# samplers_diffusers = [
+#     ("DDIMScheduler", "diffusers.schedulers.DDIMScheduler", {})
+#     ("DDPMScheduler", "diffusers.schedulers.DDPMScheduler", {})
+#     ("DEISMultistepScheduler", "diffusers.schedulers.DEISMultistepScheduler", {})
+# ]
+
+start_time = time.time()
+timeout = 90
+
+scheduler = DDIMScheduler.from_pretrained(
+    base_model,
+    subfolder="scheduler",
+)
+vae = AutoencoderKL.from_pretrained(
+    "stabilityai/sd-vae-ft-ema",
+    torch_dtype=torch.float16
+)
+text_encoder = CLIPTextModel.from_pretrained(
+    base_model,
+    subfolder="text_encoder",
+    torch_dtype=torch.float16,
+)
+tokenizer = CLIPTokenizer.from_pretrained(
+    base_model,
+    subfolder="tokenizer",
+    torch_dtype=torch.float16,
+)
+unet = UNet2DConditionModel.from_pretrained(
+    base_model,
+    subfolder="unet",
+    torch_dtype=torch.float16,
+)
+pipe = StableDiffusionPipeline(
+    text_encoder=text_encoder,
+    tokenizer=tokenizer,
+    unet=unet,
+    vae=vae,
+    scheduler=scheduler,
+)
+
+unet.set_attn_processor(CrossAttnProcessor)
+pipe.setup_text_encoder(clip_skip, text_encoder)
+if torch.cuda.is_available():
+    pipe = pipe.to("cuda")
+
+def get_model_list():
+    return models
+
+te_cache = {
+    base_model: text_encoder
+}
+
+unet_cache = {
+    base_model: unet
+}
+
+lora_cache = {
+    base_model: LoRANetwork(text_encoder, unet)
+}
+
+te_base_weight_length = text_encoder.get_input_embeddings().weight.data.shape[0]
+original_prepare_for_tokenization = tokenizer.prepare_for_tokenization
+current_model = base_model
+
+def setup_model(name, lora_state=None, lora_scale=1.0):
+    global pipe, current_model
+
+    keys = [k[0] for k in models]
+    model = models[keys.index(name)][1]
+    if model not in unet_cache:
+        unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", torch_dtype=torch.float16)
+        text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder", torch_dtype=torch.float16)
+
+        unet_cache[model] = unet
+        te_cache[model] = text_encoder
+        lora_cache[model] = LoRANetwork(text_encoder, unet)
+
+    if current_model != model:
+        if current_model not in keep_vram:
+            # offload current model
+            unet_cache[current_model].to("cpu")
+            te_cache[current_model].to("cpu")
+            lora_cache[current_model].to("cpu")
+        current_model = model
+
+    local_te, local_unet, local_lora = te_cache[model], unet_cache[model], lora_cache[model]
+    local_unet.set_attn_processor(CrossAttnProcessor())
+    local_lora.reset()
+    clip_skip = models[keys.index(name)][2]
+
+    if torch.cuda.is_available():
+        local_unet.to("cuda")
+        local_te.to("cuda")
+
+    if lora_state is not None and lora_state != "":
+        local_lora.load(lora_state, lora_scale)
+        local_lora.to(local_unet.device, dtype=local_unet.dtype)
+
+    pipe.text_encoder, pipe.unet = local_te, local_unet
+    pipe.setup_unet(local_unet)
+    pipe.tokenizer.prepare_for_tokenization = original_prepare_for_tokenization
+    pipe.tokenizer.added_tokens_encoder = {}
+    pipe.tokenizer.added_tokens_decoder = {}
+    pipe.setup_text_encoder(clip_skip, local_te)
+    return pipe
+
+
+def error_str(error, title="Error"):
+    return (
+        f"""#### {title}
+            {error}"""
+        if error
+        else ""
+    )
+
+def make_token_names(embs):
+    all_tokens = []
+    for name, vec in embs.items():
+        tokens = [f'emb-{name}-{i}' for i in range(len(vec))]
+        all_tokens.append(tokens)
+    return all_tokens
+
+def setup_tokenizer(tokenizer, embs):
+    reg_match = [re.compile(fr"(?:^|(?<=\s|,)){k}(?=,|\s|$)") for k in embs.keys()]
+    clip_keywords = [' '.join(s) for s in make_token_names(embs)]
+
+    def parse_prompt(prompt: str):
+        for m, v in zip(reg_match, clip_keywords):
+            prompt = m.sub(v, prompt)
+        return prompt
+
+    def prepare_for_tokenization(self, text: str, is_split_into_words: bool = False, **kwargs):
+        text = parse_prompt(text)
+        r = original_prepare_for_tokenization(text, is_split_into_words, **kwargs)
+        return r
+    tokenizer.prepare_for_tokenization = prepare_for_tokenization.__get__(tokenizer, CLIPTokenizer)
+    return [t for sublist in make_token_names(embs) for t in sublist]
+
+
+def convert_size(size_bytes):
+    if size_bytes == 0:
+        return "0B"
+    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
+    i = int(math.floor(math.log(size_bytes, 1024)))
+    p = math.pow(1024, i)
+    s = round(size_bytes / p, 2)
+    return "%s %s" % (s, size_name[i])
+
+def inference(
+    prompt,
+    guidance,
+    steps,
+    width=512,
+    height=512,
+    seed=0,
+    neg_prompt="",
+    state=None,
+    g_strength=0.4,
+    img_input=None,
+    i2i_scale=0.5,
+    hr_enabled=False,
+    hr_method="Latent",
+    hr_scale=1.5,
+    hr_denoise=0.8,
+    sampler="DPM++ 2M Karras",
+    embs=None,
+    model=None,
+    lora_state=None,
+    lora_scale=None,
+):
+    if seed is None or seed == 0:
+        seed = random.randint(0, 2147483647)
+
+    pipe = setup_model(model, lora_state, lora_scale)
+    generator = torch.Generator("cuda").manual_seed(int(seed))
+    start_time = time.time()
+
+    sampler_name, sampler_opt = None, None
+    for label, funcname, options in samplers_k_diffusion:
+        if label == sampler:
+            sampler_name, sampler_opt = funcname, options
+
+    tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder
+    if embs is not None and len(embs) > 0:
+        ti_embs = {}
+        for name, file in embs.items():
+            if str(file).endswith(".pt"):
+                loaded_learned_embeds = torch.load(file, map_location="cpu")
+            else:
+                loaded_learned_embeds = load_file(file, device="cpu")
+            loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"] if "string_to_param" in loaded_learned_embeds else loaded_learned_embeds
+            ti_embs[name] = loaded_learned_embeds
+
+        if len(ti_embs) > 0:
+            tokens = setup_tokenizer(tokenizer, ti_embs)
+            added_tokens = tokenizer.add_tokens(tokens)
+            delta_weight = torch.cat([val for val in ti_embs.values()], dim=0)
+
+            assert added_tokens == delta_weight.shape[0]
+            text_encoder.resize_token_embeddings(len(tokenizer))
+            token_embeds = text_encoder.get_input_embeddings().weight.data
+            token_embeds[-delta_weight.shape[0]:] = delta_weight
+
+    config = {
+        "negative_prompt": neg_prompt,
+        "num_inference_steps": int(steps),
+        "guidance_scale": guidance,
+        "generator": generator,
+        "sampler_name": sampler_name,
+        "sampler_opt": sampler_opt,
+        "pww_state": state,
+        "pww_attn_weight": g_strength,
+        "start_time": start_time,
+        "timeout": timeout,
+    }
+
+    if img_input is not None:
+        ratio = min(height / img_input.height, width / img_input.width)
+        img_input = img_input.resize(
+            (int(img_input.width * ratio), int(img_input.height * ratio)), Image.LANCZOS
+        )
+        result = pipe.img2img(prompt, image=img_input, strength=i2i_scale, **config)
+    elif hr_enabled:
+        result = pipe.txt2img(
+            prompt,
+            width=width,
+            height=height,
+            upscale=True,
+            upscale_x=hr_scale,
+            upscale_denoising_strength=hr_denoise,
+            **config,
+            **latent_upscale_modes[hr_method],
+        )
+    else:
+        result = pipe.txt2img(prompt, width=width, height=height, **config)
+
+    end_time = time.time()
+    vram_free, vram_total = torch.cuda.mem_get_info()
+    print(f"done: model={model}, res={width}x{height}, step={steps}, time={round(end_time-start_time, 2)}s, vram_alloc={convert_size(vram_total-vram_free)}/{convert_size(vram_total)}")
+    return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}")
+
+
+color_list = []
+
+
+def get_color(n):
+    for _ in range(n - len(color_list)):
+        color_list.append(tuple(np.random.random(size=3) * 256))
+    return color_list
+
+
+def create_mixed_img(current, state, w=512, h=512):
+    w, h = int(w), int(h)
+    image_np = np.full([h, w, 4], 255)
+    if state is None:
+        state = {}
+
+    colors = get_color(len(state))
+    idx = 0
+
+    for key, item in state.items():
+        if item["map"] is not None:
+            m = item["map"] < 255
+            alpha = 150
+            if current == key:
+                alpha = 200
+            image_np[m] = colors[idx] + (alpha,)
+        idx += 1
+
+    return image_np
+
+
+# width.change(apply_new_res, inputs=[width, height, global_stats], outputs=[global_stats, sp, rendered])
+def apply_new_res(w, h, state):
+    w, h = int(w), int(h)
+
+    for key, item in state.items():
+        if item["map"] is not None:
+            item["map"] = resize(item["map"], w, h)
+
+    update_img = gr.Image.update(value=create_mixed_img("", state, w, h))
+    return state, update_img
+
+
+def detect_text(text, state, width, height):
+
+    if text is None or text == "":
+        return None, None, gr.Radio.update(value=None), None
+
+    t = text.split(",")
+    new_state = {}
+
+    for item in t:
+        item = item.strip()
+        if item == "":
+            continue
+        if state is not None and item in state:
+            new_state[item] = {
+                "map": state[item]["map"],
+                "weight": state[item]["weight"],
+                "mask_outsides": state[item]["mask_outsides"],
+            }
+        else:
+            new_state[item] = {
+                "map": None,
+                "weight": 0.5,
+                "mask_outsides": False
+            }
+    update = gr.Radio.update(choices=[key for key in new_state.keys()], value=None)
+    update_img = gr.update(value=create_mixed_img("", new_state, width, height))
+    update_sketch = gr.update(value=None, interactive=False)
+    return new_state, update_sketch, update, update_img
+
+
+def resize(img, w, h):
+    trs = transforms.Compose(
+        [
+            transforms.ToPILImage(),
+            transforms.Resize(min(h, w)),
+            transforms.CenterCrop((h, w)),
+        ]
+    )
+    result = np.array(trs(img), dtype=np.uint8)
+    return result
+
+
+def switch_canvas(entry, state, width, height):
+    if entry == None:
+        return None, 0.5, False, create_mixed_img("", state, width, height)
+
+    return (
+        gr.update(value=None, interactive=True),
+        gr.update(value=state[entry]["weight"] if entry in state else 0.5),
+        gr.update(value=state[entry]["mask_outsides"] if entry in state else False),
+        create_mixed_img(entry, state, width, height),
+    )
+
+
+def apply_canvas(selected, draw, state, w, h):
+    if selected in state:
+        w, h = int(w), int(h)
+        state[selected]["map"] = resize(draw, w, h)
+    return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))
+
+
+def apply_weight(selected, weight, state):
+    if selected in state:
+        state[selected]["weight"] = weight
+    return state
+
+
+def apply_option(selected, mask, state):
+    if selected in state:
+        state[selected]["mask_outsides"] = mask
+    return state
+
+
+# sp2, radio, width, height, global_stats
+def apply_image(image, selected, w, h, strgength, mask, state):
+    if selected in state:
+        state[selected] = {
+            "map": resize(image, w, h),
+            "weight": strgength,
+            "mask_outsides": mask
+        }
+
+    return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))
+
+
+# [ti_state, lora_state, ti_vals, lora_vals, uploads]
+def add_net(files, ti_state, lora_state):
+    if files is None:
+        return ti_state, "", lora_state, None
+
+    for file in files:
+        item = Path(file.name)
+        stripedname = str(item.stem).strip()
+        if item.suffix == ".pt":
+            state_dict = torch.load(file.name, map_location="cpu")
+        else:
+            state_dict = load_file(file.name, device="cpu")
+        if any("lora" in k for k in state_dict.keys()):
+            lora_state = file.name
+        else:
+            ti_state[stripedname] = file.name
+
+    return (
+        ti_state,
+        lora_state,
+        gr.Text.update(f"{[key for key in ti_state.keys()]}"),
+        gr.Text.update(f"{lora_state}"),
+        gr.Files.update(value=None),
+    )
+
+
+# [ti_state, lora_state, ti_vals, lora_vals, uploads]
+def clean_states(ti_state, lora_state):
+    return (
+        dict(),
+        None,
+        gr.Text.update(f""),
+        gr.Text.update(f""),
+        gr.File.update(value=None),
+    )
+
+
+latent_upscale_modes = {
+    "Latent": {"upscale_method": "bilinear", "upscale_antialias": False},
+    "Latent (antialiased)": {"upscale_method": "bilinear", "upscale_antialias": True},
+    "Latent (bicubic)": {"upscale_method": "bicubic", "upscale_antialias": False},
+    "Latent (bicubic antialiased)": {
+        "upscale_method": "bicubic",
+        "upscale_antialias": True,
+    },
+    "Latent (nearest)": {"upscale_method": "nearest", "upscale_antialias": False},
+    "Latent (nearest-exact)": {
+        "upscale_method": "nearest-exact",
+        "upscale_antialias": False,
+    },
+}
| 475 | 
            +
             | 
| 476 | 
            +
            css = """
         | 
| 477 | 
            +
            .finetuned-diffusion-div div{
         | 
| 478 | 
            +
                display:inline-flex;
         | 
| 479 | 
            +
                align-items:center;
         | 
| 480 | 
            +
                gap:.8rem;
         | 
| 481 | 
            +
                font-size:1.75rem;
         | 
| 482 | 
            +
                padding-top:2rem;
         | 
| 483 | 
            +
            }
         | 
| 484 | 
            +
            .finetuned-diffusion-div div h1{
         | 
| 485 | 
            +
                font-weight:900;
         | 
| 486 | 
            +
                margin-bottom:7px
         | 
| 487 | 
            +
            }
         | 
| 488 | 
            +
            .finetuned-diffusion-div p{
         | 
| 489 | 
            +
                margin-bottom:10px;
         | 
| 490 | 
            +
                font-size:94%
         | 
| 491 | 
            +
            }
         | 
| 492 | 
            +
            .box {
         | 
| 493 | 
            +
              float: left;
         | 
| 494 | 
            +
              height: 20px;
         | 
| 495 | 
            +
              width: 20px;
         | 
| 496 | 
            +
              margin-bottom: 15px;
         | 
| 497 | 
            +
              border: 1px solid black;
         | 
| 498 | 
            +
              clear: both;
         | 
| 499 | 
            +
            }
         | 
| 500 | 
            +
            a{
         | 
| 501 | 
            +
                text-decoration:underline
         | 
| 502 | 
            +
            }
         | 
| 503 | 
            +
            .tabs{
         | 
| 504 | 
            +
                margin-top:0;
         | 
| 505 | 
            +
                margin-bottom:0
         | 
| 506 | 
            +
            }
         | 
| 507 | 
            +
            #gallery{
         | 
| 508 | 
            +
                min-height:20rem
         | 
| 509 | 
            +
            }
         | 
| 510 | 
            +
            .no-border {
         | 
| 511 | 
            +
                border: none !important;
         | 
| 512 | 
            +
            }
         | 
| 513 | 
            +
             """
         | 
| 514 | 
            +
            with gr.Blocks(css=css) as demo:
         | 
| 515 | 
            +
                gr.HTML(
         | 
| 516 | 
            +
                    f"""
         | 
| 517 | 
            +
                        <div class="finetuned-diffusion-div">
         | 
| 518 | 
            +
                          <div>
         | 
| 519 | 
            +
                            <h1>Demo for diffusion models</h1>
         | 
| 520 | 
            +
                          </div>
         | 
| 521 | 
            +
                          <p>Hso @ nyanko.sketch2img.gradio</p>
         | 
| 522 | 
            +
                        </div>
         | 
| 523 | 
            +
                    """
         | 
| 524 | 
            +
                )
         | 
| 525 | 
            +
                global_stats = gr.State(value={})
         | 
| 526 | 
            +
             | 
| 527 | 
            +
                with gr.Row():
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                    with gr.Column(scale=55):
         | 
| 530 | 
            +
                        model = gr.Dropdown(
         | 
| 531 | 
            +
                            choices=[k[0] for k in get_model_list()],
         | 
| 532 | 
            +
                            label="Model",
         | 
| 533 | 
            +
                            value=base_name,
         | 
| 534 | 
            +
                        )
         | 
| 535 | 
            +
            image_out = gr.Image(height=512)
        # gallery = gr.Gallery(
        #     label="Generated images", show_label=False, elem_id="gallery"
        # ).style(grid=[1], height="auto")

        with gr.Column(scale=45):

            with gr.Group():

                with gr.Row():
                    with gr.Column(scale=70):

                        prompt = gr.Textbox(
                            label="Prompt",
                            value="loli cat girl, blue eyes, flat chest, solo, long messy silver hair, blue capelet, cat ears, cat tail, upper body",
                            show_label=True,
                            max_lines=4,
                            placeholder="Enter prompt.",
                        )
                        neg_prompt = gr.Textbox(
                            label="Negative Prompt",
                            value="bad quality, low quality, jpeg artifact, cropped",
                            show_label=True,
                            max_lines=4,
                            placeholder="Enter negative prompt.",
                        )

                    generate = gr.Button(value="Generate").style(
                        rounded=(False, True, True, False)
                    )

            with gr.Tab("Options"):

                with gr.Group():

                    # n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1)
                    with gr.Row():
                        guidance = gr.Slider(
                            label="Guidance scale", value=7.5, maximum=15
                        )
                        steps = gr.Slider(
                            label="Steps", value=25, minimum=2, maximum=50, step=1
                        )

                    with gr.Row():
                        width = gr.Slider(
                            label="Width", value=512, minimum=64, maximum=768, step=64
                        )
                        height = gr.Slider(
                            label="Height", value=512, minimum=64, maximum=768, step=64
                        )

                    sampler = gr.Dropdown(
                        value="DPM++ 2M Karras",
                        label="Sampler",
                        choices=[s[0] for s in samplers_k_diffusion],
                    )
                    seed = gr.Number(label="Seed (0 = random)", value=0)

            with gr.Tab("Image to image"):
                with gr.Group():

                    inf_image = gr.Image(
                        label="Image", height=256, tool="editor", type="pil"
                    )
                    inf_strength = gr.Slider(
                        label="Transformation strength",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.5,
                    )

            def res_cap(g, w, h, x):
                if g:
                    return f"Enable upscaler: {w}x{h} to {int(w*x)}x{int(h*x)}"
                else:
                    return "Enable upscaler"

            with gr.Tab("Hires fix"):
                with gr.Group():

                    hr_enabled = gr.Checkbox(label="Enable upscaler", value=False)
                    hr_method = gr.Dropdown(
                        [key for key in latent_upscale_modes.keys()],
                        value="Latent",
                        label="Upscale method",
                    )
                    hr_scale = gr.Slider(
                        label="Upscale factor",
                        minimum=1.0,
                        maximum=1.5,
                        step=0.1,
                        value=1.2,
                    )
                    hr_denoise = gr.Slider(
                        label="Denoising strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.8,
                    )

                    hr_scale.change(
                        lambda g, x, w, h: gr.Checkbox.update(
                            label=res_cap(g, w, h, x)
                        ),
                        inputs=[hr_enabled, hr_scale, width, height],
                        outputs=hr_enabled,
                        queue=False,
                    )
                    hr_enabled.change(
                        lambda g, x, w, h: gr.Checkbox.update(
                            label=res_cap(g, w, h, x)
                        ),
                        inputs=[hr_enabled, hr_scale, width, height],
                        outputs=hr_enabled,
                        queue=False,
                    )

            with gr.Tab("Embeddings/Loras"):

                ti_state = gr.State(dict())
                lora_state = gr.State()

                with gr.Group():
                    with gr.Row():
                        with gr.Column(scale=90):
                            ti_vals = gr.Text(label="Loaded embeddings")

                    with gr.Row():
                        with gr.Column(scale=90):
                            lora_vals = gr.Text(label="Loaded loras")

                with gr.Row():

                    uploads = gr.Files(label="Upload new embeddings/lora")

                    with gr.Column():
                        lora_scale = gr.Slider(
                            label="Lora scale",
                            minimum=0,
                            maximum=2,
                            step=0.01,
                            value=1.0,
                        )
                        btn = gr.Button(value="Upload")
                        btn_del = gr.Button(value="Reset")

                btn.click(
                    add_net,
                    inputs=[uploads, ti_state, lora_state],
                    outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
                    queue=False,
                )
                btn_del.click(
                    clean_states,
                    inputs=[ti_state, lora_state],
                    outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
                    queue=False,
                )

        # error_output = gr.Markdown()

    gr.HTML(
        f"""
            <div class="finetuned-diffusion-div">
              <div>
                <h1>Paint with words</h1>
              </div>
              <p>
                Uses the following formula: w = scale * token_weight_matrix * log(1 + sigma) * max(qk).
              </p>
            </div>
        """
    )

    with gr.Row():

        with gr.Column(scale=55):

            rendered = gr.Image(
                invert_colors=True,
                source="canvas",
                interactive=False,
                image_mode="RGBA",
            )

        with gr.Column(scale=45):

            with gr.Group():
                with gr.Row():
                    with gr.Column(scale=70):
                        g_strength = gr.Slider(
                            label="Weight scaling",
                            minimum=0,
                            maximum=0.8,
                            step=0.01,
                            value=0.4,
                        )

                        text = gr.Textbox(
                            lines=2,
                            interactive=True,
                            label="Token to Draw: (Separate by comma)",
                        )

                        radio = gr.Radio([], label="Tokens")

                    sk_update = gr.Button(value="Update").style(
                        rounded=(False, True, True, False)
                    )

                # g_strength.change(lambda b: gr.update(f"Scaled additional attn: $w = {b} \log (1 + \sigma) \std (Q^T K)$."), inputs=g_strength, outputs=[g_output])

            with gr.Tab("SketchPad"):

                sp = gr.Image(
                    image_mode="L",
                    tool="sketch",
                    source="canvas",
                    interactive=False,
                )

                mask_outsides = gr.Checkbox(
                    label="Mask other areas",
                    value=False
                )

                strength = gr.Slider(
                    label="Token strength",
                    minimum=0,
                    maximum=0.8,
                    step=0.01,
                    value=0.5,
                )

                sk_update.click(
                    detect_text,
                    inputs=[text, global_stats, width, height],
                    outputs=[global_stats, sp, radio, rendered],
                    queue=False,
                )
                radio.change(
                    switch_canvas,
                    inputs=[radio, global_stats, width, height],
                    outputs=[sp, strength, mask_outsides, rendered],
                    queue=False,
                )
                sp.edit(
                    apply_canvas,
                    inputs=[radio, sp, global_stats, width, height],
                    outputs=[global_stats, rendered],
                    queue=False,
                )
                strength.change(
                    apply_weight,
                    inputs=[radio, strength, global_stats],
                    outputs=[global_stats],
                    queue=False,
                )
                mask_outsides.change(
                    apply_option,
                    inputs=[radio, mask_outsides, global_stats],
                    outputs=[global_stats],
                    queue=False,
                )

            with gr.Tab("UploadFile"):

                sp2 = gr.Image(
                    image_mode="L",
                    source="upload",
                    shape=(512, 512),
                )

                mask_outsides2 = gr.Checkbox(
                    label="Mask other areas",
                    value=False,
                )

                strength2 = gr.Slider(
                    label="Token strength",
                    minimum=0,
                    maximum=0.8,
                    step=0.01,
                    value=0.5,
                )

                apply_style = gr.Button(value="Apply")
                apply_style.click(
                    apply_image,
                    inputs=[sp2, radio, width, height, strength2, mask_outsides2, global_stats],
                    outputs=[global_stats, rendered],
                    queue=False,
                )

            width.change(
                apply_new_res,
                inputs=[width, height, global_stats],
                outputs=[global_stats, rendered],
                queue=False,
            )
            height.change(
                apply_new_res,
                inputs=[width, height, global_stats],
                outputs=[global_stats, rendered],
                queue=False,
            )

    # color_stats = gr.State(value={})
    # text.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])
    # sp.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])

    inputs = [
        prompt,
        guidance,
        steps,
        width,
        height,
        seed,
        neg_prompt,
        global_stats,
        g_strength,
        inf_image,
        inf_strength,
        hr_enabled,
        hr_method,
        hr_scale,
        hr_denoise,
        sampler,
        ti_state,
        model,
        lora_state,
        lora_scale,
    ]
    outputs = [image_out]
    prompt.submit(inference, inputs=inputs, outputs=outputs)
    generate.click(inference, inputs=inputs, outputs=outputs)

print(f"Space built in {time.time() - start_time:.2f} seconds")
# demo.launch(share=True)
demo.launch(enable_queue=True, server_name="0.0.0.0", server_port=7860)
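
Note on the "Paint with words" formula shown in the UI above (w = scale * token_weight_matrix * log(1 + sigma) * max(qk)): below is a minimal sketch of a weight function with that shape. The function name and the fixed scale of 0.4 are illustrative stand-ins (in the app the scale comes from the g_strength slider and the function is handed to the pipeline as weight_func); only the (w, sigma, qk) signature is taken from how modules/model.py calls it.

import math

import torch


def example_weight_func(w: torch.Tensor, sigma: float, qk: torch.Tensor) -> torch.Tensor:
    # w: painted token-weight mask, sigma: current noise level, qk: raw attention scores.
    # Returns the additive attention bias: scale * w * log(1 + sigma) * max(qk).
    scale = 0.4  # illustrative; mirrors the "Weight scaling" slider default above
    return scale * w * math.log(1.0 + float(sigma)) * qk.max()
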
    	
        modules/lora.py
    ADDED

    @@ -0,0 +1,183 @@
# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py#L48

import math
import os
import torch
import modules.safe as _
from safetensors.torch import load_file


class LoRAModule(torch.nn.Module):
    """
    Replaces the forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
            self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error

        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.enable = False

    def resize(self, rank, alpha, multiplier):
        self.alpha = torch.tensor(alpha)
        self.multiplier = multiplier
        self.scale = alpha / rank
        if self.lora_down.__class__.__name__ == "Conv2d":
            in_dim = self.lora_down.in_channels
            out_dim = self.lora_up.out_channels
            self.lora_down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False)
            self.lora_up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False)
        else:
            in_dim = self.lora_down.in_features
            out_dim = self.lora_up.out_features
            self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
            self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)

    def apply(self):
        if hasattr(self, "org_module"):
            self.org_forward = self.org_module.forward
            self.org_module.forward = self.forward
            del self.org_module

    def forward(self, x):
        if self.enable:
            return (
                self.org_forward(x)
                + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
            )
        return self.org_forward(x)


class LoRANetwork(torch.nn.Module):
    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
        super().__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha

        # create module instances
        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules):
            loras = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")
                            lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
                            loras.append(lora)
            return loras

        if isinstance(text_encoder, list):
            self.text_encoder_loras = text_encoder
        else:
            self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
            print(f"Create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
        print(f"Create LoRA for U-Net: {len(self.unet_loras)} modules.")

        self.weights_sd = None

        # assertion
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

            lora.apply()
            self.add_module(lora.lora_name, lora)

    def reset(self):
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.enable = False

    def load(self, file, scale):

        weights = None
        if os.path.splitext(file)[1] == ".safetensors":
            weights = load_file(file)
        else:
            weights = torch.load(file, map_location="cpu")

        if not weights:
            return

        network_alpha = None
        network_dim = None
        for key, value in weights.items():
            if network_alpha is None and "alpha" in key:
                network_alpha = value
            if network_dim is None and "lora_down" in key and len(value.size()) == 2:
                network_dim = value.size()[0]

        if network_alpha is None:
            network_alpha = network_dim

        weights_has_text_encoder = weights_has_unet = False
        weights_to_modify = []

        for key in weights.keys():
            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                weights_has_text_encoder = True

            if key.startswith(LoRANetwork.LORA_PREFIX_UNET):
                weights_has_unet = True

        if weights_has_text_encoder:
            weights_to_modify += self.text_encoder_loras

        if weights_has_unet:
            weights_to_modify += self.unet_loras

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.resize(network_dim, network_alpha, scale)
            if lora in weights_to_modify:
                lora.enable = True

        info = self.load_state_dict(weights, False)
        if len(info.unexpected_keys) > 0:
            print(f"Weights are loaded. Unexpected keys={info.unexpected_keys}")
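
In short, each LoRAModule hooks the forward of a matched Linear or 1x1 Conv2d layer so that, once enabled, the output becomes org_forward(x) + lora_up(lora_down(x)) * multiplier * alpha / rank, while LoRANetwork collects those hooks for the text encoder and U-Net and loads checkpoint weights into them. A hypothetical usage sketch follows; `pipe` (an already-loaded diffusers pipeline) and the checkpoint filename are assumptions, not part of this Space.

from modules.lora import LoRANetwork

# `pipe` is assumed to expose .text_encoder and .unet like a diffusers pipeline.
lora_net = LoRANetwork(pipe.text_encoder, pipe.unet)      # patch Linear / 1x1 Conv2d forwards
lora_net.to(pipe.unet.device, dtype=pipe.unet.dtype)      # plain nn.Module, moves like any other
lora_net.load("character_style.safetensors", scale=1.0)   # infer rank/alpha, load weights, enable
# ... run the pipeline as usual; enabled modules add their low-rank term ...
lora_net.reset()                                          # disable the LoRA contribution again
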
    	
        modules/model.py
    ADDED

    @@ -0,0 +1,897 @@
| 1 | 
            +
            import importlib
         | 
| 2 | 
            +
            import inspect
         | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
            import re
         | 
| 6 | 
            +
            from collections import defaultdict
         | 
| 7 | 
            +
            from typing import List, Optional, Union
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import time
         | 
| 10 | 
            +
            import k_diffusion
         | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            import PIL
         | 
| 13 | 
            +
            import torch
         | 
| 14 | 
            +
            import torch.nn as nn
         | 
| 15 | 
            +
            import torch.nn.functional as F
         | 
| 16 | 
            +
            from einops import rearrange
         | 
| 17 | 
            +
            from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
         | 
| 18 | 
            +
            from modules.prompt_parser import FrozenCLIPEmbedderWithCustomWords
         | 
| 19 | 
            +
            from torch import einsum
         | 
| 20 | 
            +
            from torch.autograd.function import Function
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from diffusers import DiffusionPipeline
         | 
| 23 | 
            +
            from diffusers.utils import PIL_INTERPOLATION, is_accelerate_available
         | 
| 24 | 
            +
            from diffusers.utils import logging, randn_tensor
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            import modules.safe as _
         | 
| 27 | 
            +
            from safetensors.torch import load_file
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            xformers_available = False
         | 
| 30 | 
            +
            try:
         | 
| 31 | 
            +
                import xformers
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                xformers_available = True
         | 
| 34 | 
            +
            except ImportError:
         | 
| 35 | 
            +
                pass
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            EPSILON = 1e-6
         | 
| 38 | 
            +
            exists = lambda val: val is not None
         | 
| 39 | 
            +
            default = lambda val, d: val if exists(val) else d
         | 
| 40 | 
            +
            logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def get_attention_scores(attn, query, key, attention_mask=None):
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                if attn.upcast_attention:
         | 
| 46 | 
            +
                    query = query.float()
         | 
| 47 | 
            +
                    key = key.float()
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                attention_scores = torch.baddbmm(
         | 
| 50 | 
            +
                    torch.empty(
         | 
| 51 | 
            +
                        query.shape[0],
         | 
| 52 | 
            +
                        query.shape[1],
         | 
| 53 | 
            +
                        key.shape[1],
         | 
| 54 | 
            +
                        dtype=query.dtype,
         | 
| 55 | 
            +
                        device=query.device,
         | 
| 56 | 
            +
                    ),
         | 
| 57 | 
            +
                    query,
         | 
| 58 | 
            +
                    key.transpose(-1, -2),
         | 
| 59 | 
            +
                    beta=0,
         | 
| 60 | 
            +
                    alpha=attn.scale,
         | 
| 61 | 
            +
                )
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                if attention_mask is not None:
         | 
| 64 | 
            +
                    attention_scores = attention_scores + attention_mask
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                if attn.upcast_softmax:
         | 
| 67 | 
            +
                    attention_scores = attention_scores.float()
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                return attention_scores
         | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
| 72 | 
            +
            class CrossAttnProcessor(nn.Module):
         | 
| 73 | 
            +
                def __call__(
         | 
| 74 | 
            +
                    self,
         | 
| 75 | 
            +
                    attn,
         | 
| 76 | 
            +
                    hidden_states,
         | 
| 77 | 
            +
                    encoder_hidden_states=None,
         | 
| 78 | 
            +
                    attention_mask=None,
         | 
| 79 | 
            +
                ):
         | 
| 80 | 
            +
                    batch_size, sequence_length, _ = hidden_states.shape
         | 
| 81 | 
            +
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    encoder_states = hidden_states
         | 
| 84 | 
            +
                    is_xattn = False
         | 
| 85 | 
            +
                    if encoder_hidden_states is not None:
         | 
| 86 | 
            +
                        is_xattn = True
         | 
| 87 | 
            +
                        img_state = encoder_hidden_states["img_state"]
         | 
| 88 | 
            +
                        encoder_states = encoder_hidden_states["states"]
         | 
| 89 | 
            +
                        weight_func = encoder_hidden_states["weight_func"]
         | 
| 90 | 
            +
                        sigma = encoder_hidden_states["sigma"]

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_states)
        value = attn.to_v(encoder_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        if is_xattn and isinstance(img_state, dict):
            # use torch.baddbmm method (slow)
            attention_scores = get_attention_scores(attn, query, key, attention_mask)
            w = img_state[sequence_length].to(query.device)
            cross_attention_weight = weight_func(w, sigma, attention_scores)
            attention_scores += torch.repeat_interleave(
                cross_attention_weight, repeats=attn.heads, dim=0
            )

            # calc probs
            attention_probs = attention_scores.softmax(dim=-1)
            attention_probs = attention_probs.to(query.dtype)
            hidden_states = torch.bmm(attention_probs, value)

        elif xformers_available:
            hidden_states = xformers.ops.memory_efficient_attention(
                query.contiguous(),
                key.contiguous(),
                value.contiguous(),
                attn_bias=attention_mask,
            )
            hidden_states = hidden_states.to(query.dtype)

        else:
            q_bucket_size = 512
            k_bucket_size = 1024

            # use flash-attention
            hidden_states = FlashAttentionFunction.apply(
                query.contiguous(),
                key.contiguous(),
                value.contiguous(),
                attention_mask,
                False,
                q_bucket_size,
                k_bucket_size,
            )
            hidden_states = hidden_states.to(query.dtype)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)

        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

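# A minimal, self-contained sketch of the score-biasing idea used in the
# `is_xattn` branch above: a per-token spatial weight map is added to the raw
# cross-attention scores before the softmax, so painted regions attend more
# strongly to their associated tokens. The helper name and tensor shapes below
# are illustrative assumptions, not part of this module's API.
def _biased_cross_attention_sketch(query, key, value, bias):
    # query: (batch*heads, q_len, dim); key/value: (batch*heads, k_len, dim)
    # bias:  (batch*heads, q_len, k_len), zero where no extra weight applies
    import torch

    scale = query.shape[-1] ** -0.5
    scores = torch.baddbmm(
        torch.zeros(
            query.shape[0], query.shape[1], key.shape[1],
            dtype=query.dtype, device=query.device,
        ),
        query,
        key.transpose(-1, -2),
        beta=0,
        alpha=scale,
    )
    probs = (scores + bias).softmax(dim=-1)
    return torch.bmm(probs, value)
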
class ModelWrapper:
    def __init__(self, model, alphas_cumprod):
        self.model = model
        self.alphas_cumprod = alphas_cumprod

    def apply_model(self, *args, **kwargs):
        if len(args) == 3:
            encoder_hidden_states = args[-1]
            args = args[:2]
        if kwargs.get("cond", None) is not None:
            encoder_hidden_states = kwargs.pop("cond")
        return self.model(
            *args, encoder_hidden_states=encoder_hidden_states, **kwargs
        ).sample


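# `ModelWrapper` adapts the diffusers UNet to the calling convention expected by
# k-diffusion's CompVis wrappers: the conditioning arrives either as a third
# positional argument or as `cond=...`, is forwarded as `encoder_hidden_states`,
# and only the raw `.sample` tensor is returned.
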
class StableDiffusionPipeline(DiffusionPipeline):

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
    ):
        super().__init__()

        # get correct sigmas from LMS
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        self.setup_unet(self.unet)
        self.setup_text_encoder()

    def setup_text_encoder(self, n=1, new_encoder=None):
        if new_encoder is not None:
            self.text_encoder = new_encoder

        self.prompt_parser = FrozenCLIPEmbedderWithCustomWords(self.tokenizer, self.text_encoder)
        self.prompt_parser.CLIP_stop_at_last_layers = n

    def setup_unet(self, unet):
        unet = unet.to(self.device)
        model = ModelWrapper(unet, self.scheduler.alphas_cumprod)
        if self.scheduler.prediction_type == "v_prediction":
            self.k_diffusion_model = CompVisVDenoiser(model)
        else:
            self.k_diffusion_model = CompVisDenoiser(model)

    def get_scheduler(self, scheduler_type: str):
        library = importlib.import_module("k_diffusion")
        sampling = getattr(library, "sampling")
        return getattr(sampling, scheduler_type)

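    # `get_scheduler` resolves a sampler purely by name from k-diffusion's
    # `sampling` module, so any sampling function defined there can be requested
    # from the UI. A hedged sketch of the lookup (assuming the k-diffusion
    # package is installed, as this module already requires):
    #
    #     sampling = getattr(importlib.import_module("k_diffusion"), "sampling")
    #     sample_fn = getattr(sampling, "sample_euler_ancestral")  # example name
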
    def encode_sketchs(self, state, scale_ratio=8, g_strength=1.0, text_ids=None):
        uncond, cond = text_ids[0], text_ids[1]

        img_state = []
        if state is None:
            return torch.FloatTensor(0)

        for k, v in state.items():
            if v["map"] is None:
                continue

            v_input = self.tokenizer(
                k,
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                add_special_tokens=False,
            ).input_ids

            dotmap = v["map"] < 255
            out = dotmap.astype(float)
            if v["mask_outsides"]:
                out[out == 0] = -1

            arr = torch.from_numpy(
                out * float(v["weight"]) * g_strength
            )
            img_state.append((v_input, arr))

        if len(img_state) == 0:
            return torch.FloatTensor(0)

        w_tensors = dict()
        cond = cond.tolist()
        uncond = uncond.tolist()
        for layer in self.unet.down_blocks:
            c = int(len(cond))
            w, h = img_state[0][1].shape
            w_r, h_r = w // scale_ratio, h // scale_ratio

            ret_cond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32)
            ret_uncond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32)

            for v_as_tokens, img_where_color in img_state:
                is_in = 0

                ret = (
                    F.interpolate(
                        img_where_color.unsqueeze(0).unsqueeze(1),
                        scale_factor=1 / scale_ratio,
                        mode="bilinear",
                        align_corners=True,
                    )
                    .squeeze()
                    .reshape(-1, 1)
                    .repeat(1, len(v_as_tokens))
                )

                for idx, tok in enumerate(cond):
                    if cond[idx : idx + len(v_as_tokens)] == v_as_tokens:
                        is_in = 1
                        ret_cond_tensor[0, :, idx : idx + len(v_as_tokens)] += ret

                for idx, tok in enumerate(uncond):
                    if uncond[idx : idx + len(v_as_tokens)] == v_as_tokens:
                        is_in = 1
                        ret_uncond_tensor[0, :, idx : idx + len(v_as_tokens)] += ret

                if not is_in == 1:
                    print(f"tokens {v_as_tokens} not found in text")

            w_tensors[w_r * h_r] = torch.cat([ret_uncond_tensor, ret_cond_tensor])
            scale_ratio *= 2

        return w_tensors

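    # `encode_sketchs` tokenizes each painted prompt fragment, downsamples its
    # mask once per UNet down block (halving the resolution each iteration), and
    # returns a dict keyed by the flattened spatial size `w_r * h_r`. The stacked
    # uncond/cond weight tensors stored there are what the attention processor
    # above looks up via `img_state[sequence_length]`.
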
    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [
            self.unet,
            self.text_encoder,
            self.vae,
            self.safety_checker,
        ]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def decode_latents(self, latents):
        latents = latents.to(self.device, dtype=self.vae.dtype)
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

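    # The memory savers above are typically toggled on the constructed pipeline
    # before generation; a minimal sketch (the variable name `pipe` and the
    # chosen settings are assumptions for illustration):
    #
    #     pipe.enable_attention_slicing()               # "auto" slicing
    #     pipe.enable_sequential_cpu_offload(gpu_id=0)  # offload idle submodules
    #     # ...generate...
    #     pipe.disable_attention_slicing()
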
    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        if (callback_steps is None) or (
            callback_steps is not None
            and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (batch_size, num_channels_latents, height // 8, width // 8)
        if latents is None:
            if device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(
                    shape, generator=generator, device="cpu", dtype=dtype
                ).to(device)
            else:
                latents = torch.randn(
                    shape, generator=generator, device=device, dtype=dtype
                )
        else:
            # if latents.shape != shape:
            #     raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        return latents

    def preprocess(self, image):
        if isinstance(image, torch.Tensor):
            return image
        elif isinstance(image, PIL.Image.Image):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            w, h = image[0].size
            w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8

            image = [
                np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
                    None, :
                ]
                for i in image
            ]
            image = np.concatenate(image, axis=0)
            image = np.array(image).astype(np.float32) / 255.0
            image = image.transpose(0, 3, 1, 2)
            image = 2.0 * image - 1.0
            image = torch.from_numpy(image)
        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0)
        return image

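    # `preprocess` accepts a tensor, a single PIL image or a list of PIL images:
    # PIL inputs are resized down to the nearest multiple of 8, scaled from
    # [0, 255] to [0, 1], remapped to [-1, 1] and converted to NCHW tensors,
    # while tensor inputs pass through unchanged.
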
            +
                @torch.no_grad()
         | 
| 437 | 
            +
                def img2img(
         | 
| 438 | 
            +
                    self,
         | 
| 439 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 440 | 
            +
                    num_inference_steps: int = 50,
         | 
| 441 | 
            +
                    guidance_scale: float = 7.5,
         | 
| 442 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 443 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 444 | 
            +
                    image: Optional[torch.FloatTensor] = None,
         | 
| 445 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 446 | 
            +
                    latents=None,
         | 
| 447 | 
            +
                    strength=1.0,
         | 
| 448 | 
            +
                    pww_state=None,
         | 
| 449 | 
            +
                    pww_attn_weight=1.0,
         | 
| 450 | 
            +
                    sampler_name="",
         | 
| 451 | 
            +
                    sampler_opt={},
         | 
| 452 | 
            +
                    start_time=-1,
         | 
| 453 | 
            +
                    timeout=180,
         | 
| 454 | 
            +
                    scale_ratio=8.0,
         | 
| 455 | 
            +
                ):
         | 
| 456 | 
            +
                    sampler = self.get_scheduler(sampler_name)
         | 
| 457 | 
            +
                    if image is not None:
         | 
| 458 | 
            +
                        image = self.preprocess(image)
         | 
| 459 | 
            +
                        image = image.to(self.vae.device, dtype=self.vae.dtype)
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                        init_latents = self.vae.encode(image).latent_dist.sample(generator)
         | 
| 462 | 
            +
                        latents = 0.18215 * init_latents
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                    # 2. Define call parameters
         | 
| 465 | 
            +
                    batch_size = 1 if isinstance(prompt, str) else len(prompt)
         | 
| 466 | 
            +
                    device = self._execution_device
         | 
| 467 | 
            +
                    latents = latents.to(device, dtype=self.unet.dtype)
         | 
| 468 | 
            +
                    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         | 
| 469 | 
            +
                    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         | 
| 470 | 
            +
                    # corresponds to doing no classifier free guidance.
         | 
| 471 | 
            +
                    do_classifier_free_guidance = True
         | 
| 472 | 
            +
                    if guidance_scale <= 1.0:
         | 
| 473 | 
            +
                        raise ValueError("has to use guidance_scale")
         | 
| 474 | 
            +
             | 
| 475 | 
            +
                    # 3. Encode input prompt
         | 
| 476 | 
            +
                    text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
         | 
| 477 | 
            +
                    text_embeddings = text_embeddings.to(self.unet.dtype)
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                    init_timestep = (
         | 
| 480 | 
            +
                        int(num_inference_steps / min(strength, 0.999)) if strength > 0 else 0
         | 
| 481 | 
            +
                    )
         | 
| 482 | 
            +
                    sigmas = self.get_sigmas(init_timestep, sampler_opt).to(
         | 
| 483 | 
            +
                        text_embeddings.device, dtype=text_embeddings.dtype
         | 
| 484 | 
            +
                    )
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                    t_start = max(init_timestep - num_inference_steps, 0)
         | 
| 487 | 
            +
                    sigma_sched = sigmas[t_start:]
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                    noise = randn_tensor(
         | 
| 490 | 
            +
                        latents.shape,
         | 
| 491 | 
            +
                        generator=generator,
         | 
| 492 | 
            +
                        device=device,
         | 
| 493 | 
            +
                        dtype=text_embeddings.dtype,
         | 
| 494 | 
            +
                    )
         | 
| 495 | 
            +
                    latents = latents.to(device)
         | 
| 496 | 
            +
                    latents = latents + noise * sigma_sched[0]
         | 
| 497 | 
            +
             | 
| 498 | 
            +
                    # 5. Prepare latent variables
         | 
| 499 | 
            +
                    self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
         | 
| 500 | 
            +
                    self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
         | 
| 501 | 
            +
                        latents.device
         | 
| 502 | 
            +
                    )
         | 
| 503 | 
            +
             | 
| 504 | 
            +
                    img_state = self.encode_sketchs(
         | 
| 505 | 
            +
                        pww_state,
         | 
| 506 | 
            +
                        g_strength=pww_attn_weight,
         | 
| 507 | 
            +
                        text_ids=text_ids,
         | 
| 508 | 
            +
                    )
         | 
| 509 | 
            +
             | 
| 510 | 
            +
                    def model_fn(x, sigma):
         | 
| 511 | 
            +
             | 
| 512 | 
            +
                        if start_time > 0 and timeout > 0:
         | 
| 513 | 
            +
                            assert (time.time() - start_time) < timeout, "inference process timed out"
         | 
| 514 | 
            +
             | 
| 515 | 
            +
                        latent_model_input = torch.cat([x] * 2)
         | 
| 516 | 
            +
                        weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
         | 
| 517 | 
            +
                        encoder_state = {
         | 
| 518 | 
            +
                            "img_state": img_state,
         | 
| 519 | 
            +
                            "states": text_embeddings,
         | 
| 520 | 
            +
                            "sigma": sigma[0],
         | 
| 521 | 
            +
                            "weight_func": weight_func,
         | 
| 522 | 
            +
                        }
         | 
| 523 | 
            +
             | 
| 524 | 
            +
                        noise_pred = self.k_diffusion_model(
         | 
| 525 | 
            +
                            latent_model_input, sigma, cond=encoder_state
         | 
| 526 | 
            +
                        )
         | 
| 527 | 
            +
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         | 
| 528 | 
            +
                        noise_pred = noise_pred_uncond + guidance_scale * (
         | 
| 529 | 
            +
                            noise_pred_text - noise_pred_uncond
         | 
| 530 | 
            +
                        )
         | 
| 531 | 
            +
                        return noise_pred
         | 
| 532 | 
            +
             | 
| 533 | 
            +
                    sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler)
         | 
| 534 | 
            +
                    latents = sampler(model_fn, latents, **sampler_args)
         | 
| 535 | 
            +
             | 
| 536 | 
            +
                    # 8. Post-processing
         | 
| 537 | 
            +
                    image = self.decode_latents(latents)
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                    # 10. Convert to PIL
         | 
| 540 | 
            +
                    if output_type == "pil":
         | 
| 541 | 
            +
                        image = self.numpy_to_pil(image)
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                    return (image,)
         | 
| 544 | 
            +
             | 
| 545 | 
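    # Inside `model_fn` above, the batched prediction is split into its
    # unconditional and text-conditioned halves and combined with standard
    # classifier-free guidance:
    #
    #     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    #
    # The per-step `sigma` also feeds `weight_func`, which scales the
    # paint-with-words attention bias by `log(1 + sigma)` times the peak score.
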
            +
                def get_sigmas(self, steps, params):
         | 
| 546 | 
            +
                    discard_next_to_last_sigma = params.get("discard_next_to_last_sigma", False)
         | 
| 547 | 
            +
                    steps += 1 if discard_next_to_last_sigma else 0
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                    if params.get("scheduler", None) == "karras":
         | 
| 550 | 
            +
                        sigma_min, sigma_max = (
         | 
| 551 | 
            +
                            self.k_diffusion_model.sigmas[0].item(),
         | 
| 552 | 
            +
                            self.k_diffusion_model.sigmas[-1].item(),
         | 
| 553 | 
            +
                        )
         | 
| 554 | 
            +
                        sigmas = k_diffusion.sampling.get_sigmas_karras(
         | 
| 555 | 
            +
                            n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=self.device
         | 
| 556 | 
            +
                        )
         | 
| 557 | 
            +
                    else:
         | 
| 558 | 
            +
                        sigmas = self.k_diffusion_model.get_sigmas(steps)
         | 
| 559 | 
            +
             | 
| 560 | 
            +
                    if discard_next_to_last_sigma:
         | 
| 561 | 
            +
                        sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
         | 
| 562 | 
            +
             | 
| 563 | 
            +
                    return sigmas
         | 
| 564 | 
            +
             | 
| 565 | 
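    # When `sampler_opt["scheduler"] == "karras"`, the noise levels are re-spaced
    # with the Karras et al. (2022) schedule over the model's [sigma_min, sigma_max]
    # range; otherwise the denoiser's native schedule is used. The
    # `discard_next_to_last_sigma` flag drops the second-to-last sigma, mirroring
    # the AUTOMATIC1111 webui option of the same name.
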
            +
                # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
         | 
| 566 | 
            +
                def get_sampler_extra_args_t2i(self, sigmas, eta, steps, func):
         | 
| 567 | 
            +
                    extra_params_kwargs = {}
         | 
| 568 | 
            +
             | 
| 569 | 
            +
                    if "eta" in inspect.signature(func).parameters:
         | 
| 570 | 
            +
                        extra_params_kwargs["eta"] = eta
         | 
| 571 | 
            +
             | 
| 572 | 
            +
                    if "sigma_min" in inspect.signature(func).parameters:
         | 
| 573 | 
            +
                        extra_params_kwargs["sigma_min"] = sigmas[0].item()
         | 
| 574 | 
            +
                        extra_params_kwargs["sigma_max"] = sigmas[-1].item()
         | 
| 575 | 
            +
             | 
| 576 | 
            +
                    if "n" in inspect.signature(func).parameters:
         | 
| 577 | 
            +
                        extra_params_kwargs["n"] = steps
         | 
| 578 | 
            +
                    else:
         | 
| 579 | 
            +
                        extra_params_kwargs["sigmas"] = sigmas
         | 
| 580 | 
            +
             | 
| 581 | 
            +
                    return extra_params_kwargs
         | 
| 582 | 
            +
             | 
| 583 | 
            +
                # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
         | 
| 584 | 
            +
                def get_sampler_extra_args_i2i(self, sigmas, func):
         | 
| 585 | 
            +
                    extra_params_kwargs = {}
         | 
| 586 | 
            +
             | 
| 587 | 
            +
                    if "sigma_min" in inspect.signature(func).parameters:
         | 
| 588 | 
            +
                        ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
         | 
| 589 | 
            +
                        extra_params_kwargs["sigma_min"] = sigmas[-2]
         | 
| 590 | 
            +
             | 
| 591 | 
            +
                    if "sigma_max" in inspect.signature(func).parameters:
         | 
| 592 | 
            +
                        extra_params_kwargs["sigma_max"] = sigmas[0]
         | 
| 593 | 
            +
             | 
| 594 | 
            +
                    if "n" in inspect.signature(func).parameters:
         | 
| 595 | 
            +
                        extra_params_kwargs["n"] = len(sigmas) - 1
         | 
| 596 | 
            +
             | 
| 597 | 
            +
                    if "sigma_sched" in inspect.signature(func).parameters:
         | 
| 598 | 
            +
                        extra_params_kwargs["sigma_sched"] = sigmas
         | 
| 599 | 
            +
             | 
| 600 | 
            +
                    if "sigmas" in inspect.signature(func).parameters:
         | 
| 601 | 
            +
                        extra_params_kwargs["sigmas"] = sigmas
         | 
| 602 | 
            +
             | 
| 603 | 
            +
                    return extra_params_kwargs
         | 
| 604 | 
            +
             | 
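    # Both helpers above introspect the chosen sampler with `inspect.signature`
    # and only pass the keyword arguments it actually accepts (`eta`, `n`,
    # `sigma_min`/`sigma_max`, `sigmas` or `sigma_sched`), so a single call site
    # can drive every k-diffusion sampling function.
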
    @torch.no_grad()
    def txt2img(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_steps: Optional[int] = 1,
        upscale=False,
        upscale_x: float = 2.0,
        upscale_method: str = "bicubic",
        upscale_antialias: bool = False,
        upscale_denoising_strength: float = 0.7,
        pww_state=None,
        pww_attn_weight=1.0,
        sampler_name="",
        sampler_opt={},
        start_time=-1,
        timeout=180,
    ):
        sampler = self.get_scheduler(sampler_name)
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = True
        if guidance_scale <= 1.0:
            raise ValueError("has to use guidance_scale")

        # 3. Encode input prompt
        text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
        text_embeddings = text_embeddings.to(self.unet.dtype)

        # 4. Prepare timesteps
        sigmas = self.get_sigmas(num_inference_steps, sampler_opt).to(
            text_embeddings.device, dtype=text_embeddings.dtype
        )

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents = latents * sigmas[0]
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
            latents.device
        )

        img_state = self.encode_sketchs(
            pww_state,
            g_strength=pww_attn_weight,
            text_ids=text_ids,
        )

        def model_fn(x, sigma):

            if start_time > 0 and timeout > 0:
                assert (time.time() - start_time) < timeout, "inference process timed out"

            latent_model_input = torch.cat([x] * 2)
            weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
            encoder_state = {
                "img_state": img_state,
                "states": text_embeddings,
                "sigma": sigma[0],
                "weight_func": weight_func,
            }

            noise_pred = self.k_diffusion_model(
                latent_model_input, sigma, cond=encoder_state
            )
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            return noise_pred

        extra_args = self.get_sampler_extra_args_t2i(
            sigmas, eta, num_inference_steps, sampler
        )
        latents = sampler(model_fn, latents, **extra_args)

        if upscale:
            target_height = height * upscale_x
            target_width = width * upscale_x
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
            latents = torch.nn.functional.interpolate(
                latents,
                size=(
                    int(target_height // vae_scale_factor),
                    int(target_width // vae_scale_factor),
                ),
                mode=upscale_method,
                antialias=upscale_antialias,
            )
            return self.img2img(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
                generator=generator,
                latents=latents,
                strength=upscale_denoising_strength,
                sampler_name=sampler_name,
                sampler_opt=sampler_opt,
                pww_state=None,
                pww_attn_weight=pww_attn_weight / 2,
            )

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return (image,)


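# A hedged, end-to-end sketch of driving `txt2img`. Everything below is
# illustrative and not part of the original module: the helper name, the model
# id and the sampler choice are assumptions; the components are loaded with the
# standard diffusers/transformers APIs this file already depends on.
def _example_txt2img(model_id="runwayml/stable-diffusion-v1-5"):
    import torch
    from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    pipe = StableDiffusionPipeline(
        vae=AutoencoderKL.from_pretrained(model_id, subfolder="vae"),
        text_encoder=CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder"),
        tokenizer=CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer"),
        unet=UNet2DConditionModel.from_pretrained(model_id, subfolder="unet"),
        scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler"),
    ).to("cuda")

    (images,) = pipe.txt2img(
        prompt="a watercolor painting of a lighthouse",
        negative_prompt="lowres, blurry",
        width=512,
        height=512,
        num_inference_steps=28,
        guidance_scale=7.5,
        sampler_name="sample_euler_ancestral",  # any function in k_diffusion.sampling
        sampler_opt={},
        generator=torch.Generator("cuda").manual_seed(0),
    )
    return images[0]
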
class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device=device)

        scale = q.shape[-1] ** -0.5

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
                    min=EPSILON
                )

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = (
                    exp_row_max_diff * row_sums
                    + exp_block_row_max_diff * block_row_sums
                )

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values
                )

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        lse = all_row_sums.log() + all_row_maxes

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, lse)

        return o

| 832 | 
            +
                @staticmethod
         | 
| 833 | 
            +
                @torch.no_grad()
         | 
| 834 | 
            +
                def backward(ctx, do):
         | 
| 835 | 
            +
                    """Algorithm 4 in the paper"""
         | 
| 836 | 
            +
             | 
| 837 | 
            +
                    causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
         | 
| 838 | 
            +
                    q, k, v, o, lse = ctx.saved_tensors
         | 
| 839 | 
            +
             | 
| 840 | 
            +
                    device = q.device
         | 
| 841 | 
            +
             | 
| 842 | 
            +
                    max_neg_value = -torch.finfo(q.dtype).max
         | 
| 843 | 
            +
                    qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
         | 
| 844 | 
            +
             | 
| 845 | 
            +
                    dq = torch.zeros_like(q)
         | 
| 846 | 
            +
                    dk = torch.zeros_like(k)
         | 
| 847 | 
            +
                    dv = torch.zeros_like(v)
         | 
| 848 | 
            +
             | 
| 849 | 
            +
                    row_splits = zip(
         | 
| 850 | 
            +
                        q.split(q_bucket_size, dim=-2),
         | 
| 851 | 
            +
                        o.split(q_bucket_size, dim=-2),
         | 
| 852 | 
            +
                        do.split(q_bucket_size, dim=-2),
         | 
| 853 | 
            +
                        mask,
         | 
| 854 | 
            +
                        lse.split(q_bucket_size, dim=-2),
         | 
| 855 | 
            +
                        dq.split(q_bucket_size, dim=-2),
         | 
| 856 | 
            +
                    )
         | 
| 857 | 
            +
             | 
| 858 | 
            +
                    for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
         | 
| 859 | 
            +
                        q_start_index = ind * q_bucket_size - qk_len_diff
         | 
| 860 | 
            +
             | 
| 861 | 
            +
                        col_splits = zip(
         | 
| 862 | 
            +
                            k.split(k_bucket_size, dim=-2),
         | 
| 863 | 
            +
                            v.split(k_bucket_size, dim=-2),
         | 
| 864 | 
            +
                            dk.split(k_bucket_size, dim=-2),
         | 
| 865 | 
            +
                            dv.split(k_bucket_size, dim=-2),
         | 
| 866 | 
            +
                        )
         | 
| 867 | 
            +
             | 
| 868 | 
            +
                        for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
         | 
| 869 | 
            +
                            k_start_index = k_ind * k_bucket_size
         | 
| 870 | 
            +
             | 
| 871 | 
            +
                            attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
         | 
| 872 | 
            +
             | 
| 873 | 
            +
                            if causal and q_start_index < (k_start_index + k_bucket_size - 1):
         | 
| 874 | 
            +
                                causal_mask = torch.ones(
         | 
| 875 | 
            +
                                    (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
         | 
| 876 | 
            +
                                ).triu(q_start_index - k_start_index + 1)
         | 
| 877 | 
            +
                                attn_weights.masked_fill_(causal_mask, max_neg_value)
         | 
| 878 | 
            +
             | 
| 879 | 
            +
                            p = torch.exp(attn_weights - lsec)
         | 
| 880 | 
            +
             | 
| 881 | 
            +
                            if exists(row_mask):
         | 
| 882 | 
            +
                                p.masked_fill_(~row_mask, 0.0)
         | 
| 883 | 
            +
             | 
| 884 | 
            +
                            dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
         | 
| 885 | 
            +
                            dp = einsum("... i d, ... j d -> ... i j", doc, vc)
         | 
| 886 | 
            +
             | 
| 887 | 
            +
                            D = (doc * oc).sum(dim=-1, keepdims=True)
         | 
| 888 | 
            +
                            ds = p * scale * (dp - D)
         | 
| 889 | 
            +
             | 
| 890 | 
            +
                            dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
         | 
| 891 | 
            +
                            dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)
         | 
| 892 | 
            +
             | 
| 893 | 
            +
                            dqc.add_(dq_chunk)
         | 
| 894 | 
            +
                            dkc.add_(dk_chunk)
         | 
| 895 | 
            +
                            dvc.add_(dv_chunk)
         | 
| 896 | 
            +
             | 
| 897 | 
            +
                    return dq, dk, dv, None, None, None, None
         | 
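
A minimal sketch (not part of the diff) of the online-softmax bookkeeping that the chunked forward pass above performs: for each key/value chunk it keeps a running row max, a running row sum and a partially renormalised output, mirroring the block_row_maxes / new_row_sums update shown above. The function name, the k_bucket_size default and the tensor shapes below are illustrative, and causality/masking are omitted for brevity.

import torch

def chunked_attention(q, k, v, k_bucket_size=128):
    """Compute softmax(q @ k^T * scale) @ v one key/value chunk at a time."""
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    row_maxes = torch.full((*q.shape[:-1], 1), -torch.finfo(q.dtype).max,
                           dtype=q.dtype, device=q.device)
    row_sums = torch.zeros((*q.shape[:-1], 1), dtype=q.dtype, device=q.device)

    for kc, vc in zip(k.split(k_bucket_size, dim=-2), v.split(k_bucket_size, dim=-2)):
        attn = torch.einsum("...id,...jd->...ij", q, kc) * scale
        block_maxes = attn.amax(dim=-1, keepdim=True)
        exp_w = torch.exp(attn - block_maxes)
        block_sums = exp_w.sum(dim=-1, keepdim=True)

        # merge the running statistics with this block's statistics
        new_maxes = torch.maximum(row_maxes, block_maxes)
        exp_row_diff = torch.exp(row_maxes - new_maxes)
        exp_block_diff = torch.exp(block_maxes - new_maxes)
        new_sums = exp_row_diff * row_sums + exp_block_diff * block_sums

        # renormalise the accumulated output and add this block's contribution
        out = out * (row_sums / new_sums) * exp_row_diff + (
            exp_block_diff / new_sums
        ) * torch.einsum("...ij,...jd->...id", exp_w, vc)

        row_maxes, row_sums = new_maxes, new_sums
    return out

# quick check against naive attention
q, k, v = (torch.randn(1, 8, 256, 64) for _ in range(3))
ref = torch.softmax(torch.einsum("...id,...jd->...ij", q, k) * q.shape[-1] ** -0.5, dim=-1) @ v
assert torch.allclose(chunked_attention(q, k, v), ref, atol=1e-4)
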
    	
        modules/prompt_parser.py
    ADDED
    
    | @@ -0,0 +1,391 @@ | |
| 1 | 
            +
             | 
| 2 | 
            +
            import re
         | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            # Code from https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/8e2aeee4a127b295bfc880800e4a312e0f049b85, modified.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            class PromptChunk:
         | 
| 10 | 
            +
                """
         | 
| 11 | 
            +
                This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
         | 
| 12 | 
            +
                If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
         | 
| 13 | 
            +
                Each PromptChunk contains exactly 77 tokens, which includes one start token and one end token,
         | 
| 14 | 
            +
                so just 75 tokens come from the prompt.
         | 
| 15 | 
            +
                """
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def __init__(self):
         | 
| 18 | 
            +
                    self.tokens = []
         | 
| 19 | 
            +
                    self.multipliers = []
         | 
| 20 | 
            +
                    self.fixes = []
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         | 
| 24 | 
            +
                """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
         | 
| 25 | 
            +
                have unlimited prompt length and assign weights to tokens in prompt.
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                def __init__(self, text_encoder, enable_emphasis=True):
         | 
| 29 | 
            +
                    super().__init__()
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    self.device = lambda: text_encoder.device
         | 
| 32 | 
            +
                    self.enable_emphasis = enable_emphasis
         | 
| 33 | 
            +
                    """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
         | 
| 34 | 
            +
                    depending on model."""
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    self.chunk_length = 75
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                def empty_chunk(self):
         | 
| 39 | 
            +
                    """creates an empty PromptChunk and returns it"""
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    chunk = PromptChunk()
         | 
| 42 | 
            +
                    chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
         | 
| 43 | 
            +
                    chunk.multipliers = [1.0] * (self.chunk_length + 2)
         | 
| 44 | 
            +
                    return chunk
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                def get_target_prompt_token_count(self, token_count):
         | 
| 47 | 
            +
                    """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                def tokenize_line(self, line):
         | 
| 52 | 
            +
                    """
         | 
| 53 | 
            +
                    this transforms a single prompt into a list of PromptChunk objects - as many as needed to
         | 
| 54 | 
            +
                    represent the prompt.
         | 
| 55 | 
            +
                    Returns the list and the total number of tokens in the prompt.
         | 
| 56 | 
            +
                    """
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    if self.enable_emphasis:
         | 
| 59 | 
            +
                        parsed = parse_prompt_attention(line)
         | 
| 60 | 
            +
                    else:
         | 
| 61 | 
            +
                        parsed = [[line, 1.0]]
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    tokenized = self.tokenize([text for text, _ in parsed])
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    chunks = []
         | 
| 66 | 
            +
                    chunk = PromptChunk()
         | 
| 67 | 
            +
                    token_count = 0
         | 
| 68 | 
            +
                    last_comma = -1
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    def next_chunk(is_last=False):
         | 
| 71 | 
            +
                        """puts current chunk into the list of results and produces the next one - empty;
         | 
| 72 | 
            +
                        if is_last is true, the <end-of-text> padding tokens at the end won't add to token_count"""
         | 
| 73 | 
            +
                        nonlocal token_count
         | 
| 74 | 
            +
                        nonlocal last_comma
         | 
| 75 | 
            +
                        nonlocal chunk
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                        if is_last:
         | 
| 78 | 
            +
                            token_count += len(chunk.tokens)
         | 
| 79 | 
            +
                        else:
         | 
| 80 | 
            +
                            token_count += self.chunk_length
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                        to_add = self.chunk_length - len(chunk.tokens)
         | 
| 83 | 
            +
                        if to_add > 0:
         | 
| 84 | 
            +
                            chunk.tokens += [self.id_end] * to_add
         | 
| 85 | 
            +
                            chunk.multipliers += [1.0] * to_add
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                        chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
         | 
| 88 | 
            +
                        chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                        last_comma = -1
         | 
| 91 | 
            +
                        chunks.append(chunk)
         | 
| 92 | 
            +
                        chunk = PromptChunk()
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    comma_padding_backtrack = 20  # default value in https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/shared.py#L410
         | 
| 95 | 
            +
                    for tokens, (text, weight) in zip(tokenized, parsed):
         | 
| 96 | 
            +
                        if text == "BREAK" and weight == -1:
         | 
| 97 | 
            +
                            next_chunk()
         | 
| 98 | 
            +
                            continue
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                        position = 0
         | 
| 101 | 
            +
                        while position < len(tokens):
         | 
| 102 | 
            +
                            token = tokens[position]
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                            if token == self.comma_token:
         | 
| 105 | 
            +
                                last_comma = len(chunk.tokens)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                            # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
         | 
| 108 | 
            +
                            # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
         | 
| 109 | 
            +
                            elif (
         | 
| 110 | 
            +
                                comma_padding_backtrack != 0
         | 
| 111 | 
            +
                                and len(chunk.tokens) == self.chunk_length
         | 
| 112 | 
            +
                                and last_comma != -1
         | 
| 113 | 
            +
                                and len(chunk.tokens) - last_comma <= comma_padding_backtrack
         | 
| 114 | 
            +
                            ):
         | 
| 115 | 
            +
                                break_location = last_comma + 1
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                                reloc_tokens = chunk.tokens[break_location:]
         | 
| 118 | 
            +
                                reloc_mults = chunk.multipliers[break_location:]
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                                chunk.tokens = chunk.tokens[:break_location]
         | 
| 121 | 
            +
                                chunk.multipliers = chunk.multipliers[:break_location]
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                                next_chunk()
         | 
| 124 | 
            +
                                chunk.tokens = reloc_tokens
         | 
| 125 | 
            +
                                chunk.multipliers = reloc_mults
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                            if len(chunk.tokens) == self.chunk_length:
         | 
| 128 | 
            +
                                next_chunk()
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                            chunk.tokens.append(token)
         | 
| 131 | 
            +
                            chunk.multipliers.append(weight)
         | 
| 132 | 
            +
                            position += 1
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    if len(chunk.tokens) > 0 or len(chunks) == 0:
         | 
| 135 | 
            +
                        next_chunk(is_last=True)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    return chunks, token_count
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                def process_texts(self, texts):
         | 
| 140 | 
            +
                    """
         | 
| 141 | 
            +
                    Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
         | 
| 142 | 
            +
                    length, in tokens, of all texts.
         | 
| 143 | 
            +
                    """
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    token_count = 0
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    cache = {}
         | 
| 148 | 
            +
                    batch_chunks = []
         | 
| 149 | 
            +
                    for line in texts:
         | 
| 150 | 
            +
                        if line in cache:
         | 
| 151 | 
            +
                            chunks = cache[line]
         | 
| 152 | 
            +
                        else:
         | 
| 153 | 
            +
                            chunks, current_token_count = self.tokenize_line(line)
         | 
| 154 | 
            +
                            token_count = max(current_token_count, token_count)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                            cache[line] = chunks
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                        batch_chunks.append(chunks)
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    return batch_chunks, token_count
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                def forward(self, texts):
         | 
| 163 | 
            +
                    """
         | 
| 164 | 
            +
                    Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
         | 
| 165 | 
            +
                    Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
         | 
| 166 | 
            +
                    be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
         | 
| 167 | 
            +
                    An example shape returned by this function can be: (2, 77, 768).
         | 
| 168 | 
            +
                    Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
         | 
| 169 | 
            +
                    is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
         | 
| 170 | 
            +
                    """
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    batch_chunks, token_count = self.process_texts(texts)
         | 
| 173 | 
            +
                    chunk_count = max([len(x) for x in batch_chunks])
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    zs = []
         | 
| 176 | 
            +
                    ts = []
         | 
| 177 | 
            +
                    for i in range(chunk_count):
         | 
| 178 | 
            +
                        batch_chunk = [
         | 
| 179 | 
            +
                            chunks[i] if i < len(chunks) else self.empty_chunk()
         | 
| 180 | 
            +
                            for chunks in batch_chunks
         | 
| 181 | 
            +
                        ]
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                        tokens = [x.tokens for x in batch_chunk]
         | 
| 184 | 
            +
                        multipliers = [x.multipliers for x in batch_chunk]
         | 
| 185 | 
            +
                        # self.embeddings.fixes = [x.fixes for x in batch_chunk]
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                        # for fixes in self.embeddings.fixes:
         | 
| 188 | 
            +
                        #     for position, embedding in fixes:
         | 
| 189 | 
            +
                        #         used_embeddings[embedding.name] = embedding
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                        z = self.process_tokens(tokens, multipliers)
         | 
| 192 | 
            +
                        zs.append(z)
         | 
| 193 | 
            +
                        ts.append(tokens)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    return np.hstack(ts), torch.hstack(zs)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                def process_tokens(self, remade_batch_tokens, batch_multipliers):
         | 
| 198 | 
            +
                    """
         | 
| 199 | 
            +
                    sends a single prompt chunk to be encoded by the transformers neural network.
         | 
| 200 | 
            +
                    remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
         | 
| 201 | 
            +
                    there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
         | 
| 202 | 
            +
                    Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
         | 
| 203 | 
            +
                    corresponds to one token.
         | 
| 204 | 
            +
                    """
         | 
| 205 | 
            +
                    tokens = torch.asarray(remade_batch_tokens).to(self.device())
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
         | 
| 208 | 
            +
                    if self.id_end != self.id_pad:
         | 
| 209 | 
            +
                        for batch_pos in range(len(remade_batch_tokens)):
         | 
| 210 | 
            +
                            index = remade_batch_tokens[batch_pos].index(self.id_end)
         | 
| 211 | 
            +
                            tokens[batch_pos, index + 1 : tokens.shape[1]] = self.id_pad
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    z = self.encode_with_transformers(tokens)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
         | 
| 216 | 
            +
                    batch_multipliers = torch.asarray(batch_multipliers).to(self.device())
         | 
| 217 | 
            +
                    original_mean = z.mean()
         | 
| 218 | 
            +
                    z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
         | 
| 219 | 
            +
                    new_mean = z.mean()
         | 
| 220 | 
            +
                    z = z * (original_mean / new_mean)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                    return z
         | 
| 223 | 
            +
             | 
| 224 | 
            +
             | 
| 225 | 
            +
            class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
         | 
| 226 | 
            +
                def __init__(self, tokenizer, text_encoder):
         | 
| 227 | 
            +
                    super().__init__(text_encoder)
         | 
| 228 | 
            +
                    self.tokenizer = tokenizer
         | 
| 229 | 
            +
                    self.text_encoder = text_encoder
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                    vocab = self.tokenizer.get_vocab()
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    self.comma_token = vocab.get(",</w>", None)
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                    self.token_mults = {}
         | 
| 236 | 
            +
                    tokens_with_parens = [
         | 
| 237 | 
            +
                        (k, v)
         | 
| 238 | 
            +
                        for k, v in vocab.items()
         | 
| 239 | 
            +
                        if "(" in k or ")" in k or "[" in k or "]" in k
         | 
| 240 | 
            +
                    ]
         | 
| 241 | 
            +
                    for text, ident in tokens_with_parens:
         | 
| 242 | 
            +
                        mult = 1.0
         | 
| 243 | 
            +
                        for c in text:
         | 
| 244 | 
            +
                            if c == "[":
         | 
| 245 | 
            +
                                mult /= 1.1
         | 
| 246 | 
            +
                            if c == "]":
         | 
| 247 | 
            +
                                mult *= 1.1
         | 
| 248 | 
            +
                            if c == "(":
         | 
| 249 | 
            +
                                mult *= 1.1
         | 
| 250 | 
            +
                            if c == ")":
         | 
| 251 | 
            +
                                mult /= 1.1
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                        if mult != 1.0:
         | 
| 254 | 
            +
                            self.token_mults[ident] = mult
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    self.id_start = self.tokenizer.bos_token_id
         | 
| 257 | 
            +
                    self.id_end = self.tokenizer.eos_token_id
         | 
| 258 | 
            +
                    self.id_pad = self.id_end
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                def tokenize(self, texts):
         | 
| 261 | 
            +
                    tokenized = self.tokenizer(
         | 
| 262 | 
            +
                        texts, truncation=False, add_special_tokens=False
         | 
| 263 | 
            +
                    )["input_ids"]
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                    return tokenized
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                def encode_with_transformers(self, tokens):
         | 
| 268 | 
            +
                    CLIP_stop_at_last_layers = 1
         | 
| 269 | 
            +
                    tokens = tokens.to(self.text_encoder.device)
         | 
| 270 | 
            +
                    outputs = self.text_encoder(tokens, output_hidden_states=True)
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                    if CLIP_stop_at_last_layers > 1:
         | 
| 273 | 
            +
                        z = outputs.hidden_states[-CLIP_stop_at_last_layers]
         | 
| 274 | 
            +
                        z = self.text_encoder.text_model.final_layer_norm(z)
         | 
| 275 | 
            +
                    else:
         | 
| 276 | 
            +
                        z = outputs.last_hidden_state
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                    return z
         | 
| 279 | 
            +
                
         | 
| 280 | 
            +
             | 
| 281 | 
            +
            re_attention = re.compile(
         | 
| 282 | 
            +
                r"""
         | 
| 283 | 
            +
            \\\(|
         | 
| 284 | 
            +
            \\\)|
         | 
| 285 | 
            +
            \\\[|
         | 
| 286 | 
            +
            \\]|
         | 
| 287 | 
            +
            \\\\|
         | 
| 288 | 
            +
            \\|
         | 
| 289 | 
            +
            \(|
         | 
| 290 | 
            +
            \[|
         | 
| 291 | 
            +
            :([+-]?[.\d]+)\)|
         | 
| 292 | 
            +
            \)|
         | 
| 293 | 
            +
            ]|
         | 
| 294 | 
            +
            [^\\()\[\]:]+|
         | 
| 295 | 
            +
            :
         | 
| 296 | 
            +
            """,
         | 
| 297 | 
            +
                re.X,
         | 
| 298 | 
            +
            )
         | 
| 299 | 
            +
             | 
| 300 | 
            +
            re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
         | 
| 301 | 
            +
             | 
| 302 | 
            +
             | 
| 303 | 
            +
            def parse_prompt_attention(text):
         | 
| 304 | 
            +
                """
         | 
| 305 | 
            +
                Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
         | 
| 306 | 
            +
                Accepted tokens are:
         | 
| 307 | 
            +
                  (abc) - increases attention to abc by a multiplier of 1.1
         | 
| 308 | 
            +
                  (abc:3.12) - increases attention to abc by a multiplier of 3.12
         | 
| 309 | 
            +
                  [abc] - decreases attention to abc by a multiplier of 1.1
         | 
| 310 | 
            +
                  \( - literal character '('
         | 
| 311 | 
            +
                  \[ - literal character '['
         | 
| 312 | 
            +
                  \) - literal character ')'
         | 
| 313 | 
            +
                  \] - literal character ']'
         | 
| 314 | 
            +
                  \\ - literal character '\'
         | 
| 315 | 
            +
                  anything else - just text
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                >>> parse_prompt_attention('normal text')
         | 
| 318 | 
            +
                [['normal text', 1.0]]
         | 
| 319 | 
            +
                >>> parse_prompt_attention('an (important) word')
         | 
| 320 | 
            +
                [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
         | 
| 321 | 
            +
                >>> parse_prompt_attention('(unbalanced')
         | 
| 322 | 
            +
                [['unbalanced', 1.1]]
         | 
| 323 | 
            +
                >>> parse_prompt_attention('\(literal\]')
         | 
| 324 | 
            +
                [['(literal]', 1.0]]
         | 
| 325 | 
            +
                >>> parse_prompt_attention('(unnecessary)(parens)')
         | 
| 326 | 
            +
                [['unnecessaryparens', 1.1]]
         | 
| 327 | 
            +
                >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
         | 
| 328 | 
            +
                [['a ', 1.0],
         | 
| 329 | 
            +
                 ['house', 1.5730000000000004],
         | 
| 330 | 
            +
                 [' ', 1.1],
         | 
| 331 | 
            +
                 ['on', 1.0],
         | 
| 332 | 
            +
                 [' a ', 1.1],
         | 
| 333 | 
            +
                 ['hill', 0.55],
         | 
| 334 | 
            +
                 [', sun, ', 1.1],
         | 
| 335 | 
            +
                 ['sky', 1.4641000000000006],
         | 
| 336 | 
            +
                 ['.', 1.1]]
         | 
| 337 | 
            +
                """
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                res = []
         | 
| 340 | 
            +
                round_brackets = []
         | 
| 341 | 
            +
                square_brackets = []
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                round_bracket_multiplier = 1.1
         | 
| 344 | 
            +
                square_bracket_multiplier = 1 / 1.1
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                def multiply_range(start_position, multiplier):
         | 
| 347 | 
            +
                    for p in range(start_position, len(res)):
         | 
| 348 | 
            +
                        res[p][1] *= multiplier
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                for m in re_attention.finditer(text):
         | 
| 351 | 
            +
                    text = m.group(0)
         | 
| 352 | 
            +
                    weight = m.group(1)
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    if text.startswith("\\"):
         | 
| 355 | 
            +
                        res.append([text[1:], 1.0])
         | 
| 356 | 
            +
                    elif text == "(":
         | 
| 357 | 
            +
                        round_brackets.append(len(res))
         | 
| 358 | 
            +
                    elif text == "[":
         | 
| 359 | 
            +
                        square_brackets.append(len(res))
         | 
| 360 | 
            +
                    elif weight is not None and len(round_brackets) > 0:
         | 
| 361 | 
            +
                        multiply_range(round_brackets.pop(), float(weight))
         | 
| 362 | 
            +
                    elif text == ")" and len(round_brackets) > 0:
         | 
| 363 | 
            +
                        multiply_range(round_brackets.pop(), round_bracket_multiplier)
         | 
| 364 | 
            +
                    elif text == "]" and len(square_brackets) > 0:
         | 
| 365 | 
            +
                        multiply_range(square_brackets.pop(), square_bracket_multiplier)
         | 
| 366 | 
            +
                    else:
         | 
| 367 | 
            +
                        parts = re.split(re_break, text)
         | 
| 368 | 
            +
                        for i, part in enumerate(parts):
         | 
| 369 | 
            +
                            if i > 0:
         | 
| 370 | 
            +
                                res.append(["BREAK", -1])
         | 
| 371 | 
            +
                            res.append([part, 1.0])
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                for pos in round_brackets:
         | 
| 374 | 
            +
                    multiply_range(pos, round_bracket_multiplier)
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                for pos in square_brackets:
         | 
| 377 | 
            +
                    multiply_range(pos, square_bracket_multiplier)
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                if len(res) == 0:
         | 
| 380 | 
            +
                    res = [["", 1.0]]
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                # merge runs of identical weights
         | 
| 383 | 
            +
                i = 0
         | 
| 384 | 
            +
                while i + 1 < len(res):
         | 
| 385 | 
            +
                    if res[i][1] == res[i + 1][1]:
         | 
| 386 | 
            +
                        res[i][0] += res[i + 1][0]
         | 
| 387 | 
            +
                        res.pop(i + 1)
         | 
| 388 | 
            +
                    else:
         | 
| 389 | 
            +
                        i += 1
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                return res
         | 
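
To make the emphasis syntax concrete, here is a small usage sketch of parse_prompt_attention as defined above (assuming the repository root is on the import path). The first two expected results follow the doctest examples in the function's docstring; the third follows from the same parsing rules.

from modules.prompt_parser import parse_prompt_attention

# plain text keeps weight 1.0
parse_prompt_attention("normal text")          # [['normal text', 1.0]]
# parentheses multiply the weight by 1.1, square brackets divide it by 1.1
parse_prompt_attention("an (important) word")  # [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
# an explicit (text:weight) value overrides the default 1.1 multiplier
parse_prompt_attention("a (hill:0.5)")         # [['a ', 1.0], ['hill', 0.5]]
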
    	
        modules/safe.py
    ADDED
    
    | @@ -0,0 +1,188 @@ | |
| 1 | 
            +
            # this code is adapted from the script contributed by anon from /h/
         | 
| 2 | 
            +
            # modified, from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/safe.py
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import io
         | 
| 5 | 
            +
            import pickle
         | 
| 6 | 
            +
            import collections
         | 
| 7 | 
            +
            import sys
         | 
| 8 | 
            +
            import traceback
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            import numpy
         | 
| 12 | 
            +
            import _codecs
         | 
| 13 | 
            +
            import zipfile
         | 
| 14 | 
            +
            import re
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
         | 
| 18 | 
            +
            TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def encode(*args):
         | 
| 22 | 
            +
                out = _codecs.encode(*args)
         | 
| 23 | 
            +
                return out
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            class RestrictedUnpickler(pickle.Unpickler):
         | 
| 27 | 
            +
                extra_handler = None
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                def persistent_load(self, saved_id):
         | 
| 30 | 
            +
                    assert saved_id[0] == 'storage'
         | 
| 31 | 
            +
                    return TypedStorage()
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                def find_class(self, module, name):
         | 
| 34 | 
            +
                    if self.extra_handler is not None:
         | 
| 35 | 
            +
                        res = self.extra_handler(module, name)
         | 
| 36 | 
            +
                        if res is not None:
         | 
| 37 | 
            +
                            return res
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    if module == 'collections' and name == 'OrderedDict':
         | 
| 40 | 
            +
                        return getattr(collections, name)
         | 
| 41 | 
            +
                    if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
         | 
| 42 | 
            +
                        return getattr(torch._utils, name)
         | 
| 43 | 
            +
                    if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
         | 
| 44 | 
            +
                        return getattr(torch, name)
         | 
| 45 | 
            +
                    if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
         | 
| 46 | 
            +
                        return getattr(torch.nn.modules.container, name)
         | 
| 47 | 
            +
                    if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
         | 
| 48 | 
            +
                        return getattr(numpy.core.multiarray, name)
         | 
| 49 | 
            +
                    if module == 'numpy' and name in ['dtype', 'ndarray']:
         | 
| 50 | 
            +
                        return getattr(numpy, name)
         | 
| 51 | 
            +
                    if module == '_codecs' and name == 'encode':
         | 
| 52 | 
            +
                        return encode
         | 
| 53 | 
            +
                    if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
         | 
| 54 | 
            +
                        import pytorch_lightning.callbacks
         | 
| 55 | 
            +
                        return pytorch_lightning.callbacks.model_checkpoint
         | 
| 56 | 
            +
                    if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
         | 
| 57 | 
            +
                        import pytorch_lightning.callbacks.model_checkpoint
         | 
| 58 | 
            +
                        return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
         | 
| 59 | 
            +
                    if module == "__builtin__" and name == 'set':
         | 
| 60 | 
            +
                        return set
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    # Forbid everything else.
         | 
| 63 | 
            +
                    raise Exception(f"global '{module}/{name}' is forbidden")
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
         | 
| 67 | 
            +
            allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
         | 
| 68 | 
            +
            data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            def check_zip_filenames(filename, names):
         | 
| 71 | 
            +
                for name in names:
         | 
| 72 | 
            +
                    if allowed_zip_names_re.match(name):
         | 
| 73 | 
            +
                        continue
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    raise Exception(f"bad file inside {filename}: {name}")
         | 
| 76 | 
            +
             | 
| 77 | 
            +
             | 
| 78 | 
            +
            def check_pt(filename, extra_handler):
         | 
| 79 | 
            +
                try:
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    # new pytorch format is a zip file
         | 
| 82 | 
            +
                    with zipfile.ZipFile(filename) as z:
         | 
| 83 | 
            +
                        check_zip_filenames(filename, z.namelist())
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                        # find filename of data.pkl in zip file: '<directory name>/data.pkl'
         | 
| 86 | 
            +
                        data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
         | 
| 87 | 
            +
                        if len(data_pkl_filenames) == 0:
         | 
| 88 | 
            +
                            raise Exception(f"data.pkl not found in {filename}")
         | 
| 89 | 
            +
                        if len(data_pkl_filenames) > 1:
         | 
| 90 | 
            +
                            raise Exception(f"Multiple data.pkl found in {filename}")
         | 
| 91 | 
            +
                        with z.open(data_pkl_filenames[0]) as file:
         | 
| 92 | 
            +
                            unpickler = RestrictedUnpickler(file)
         | 
| 93 | 
            +
                            unpickler.extra_handler = extra_handler
         | 
| 94 | 
            +
                            unpickler.load()
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                except zipfile.BadZipfile:
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    # if it's not a zip file, it's the old pytorch format, with five objects written to the pickle
         | 
| 99 | 
            +
                    with open(filename, "rb") as file:
         | 
| 100 | 
            +
                        unpickler = RestrictedUnpickler(file)
         | 
| 101 | 
            +
                        unpickler.extra_handler = extra_handler
         | 
| 102 | 
            +
                        for i in range(5):
         | 
| 103 | 
            +
                            unpickler.load()
         | 
| 104 | 
            +
             | 
| 105 | 
            +
             | 
| 106 | 
            +
            def load(filename, *args, **kwargs):
         | 
| 107 | 
            +
                return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
            def load_with_extra(filename, extra_handler=None, *args, **kwargs):
         | 
| 111 | 
            +
                """
         | 
| 112 | 
            +
                this function is intended to be used by extensions that want to load models with
         | 
| 113 | 
            +
                some extra classes in them that the usual unpickler would find suspicious.
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                Use the extra_handler argument to specify a function that takes module and field name as text,
         | 
| 116 | 
            +
                and returns that field's value:
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                ```python
         | 
| 119 | 
            +
                def extra(module, name):
         | 
| 120 | 
            +
                    if module == 'collections' and name == 'OrderedDict':
         | 
| 121 | 
            +
                        return collections.OrderedDict
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    return None
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                safe.load_with_extra('model.pt', extra_handler=extra)
         | 
| 126 | 
            +
                ```
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
         | 
| 129 | 
            +
                definitely unsafe.
         | 
| 130 | 
            +
                """
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                try:
         | 
| 133 | 
            +
                    check_pt(filename, extra_handler)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                except pickle.UnpicklingError:
         | 
| 136 | 
            +
                    print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
         | 
| 137 | 
            +
                    print(traceback.format_exc(), file=sys.stderr)
         | 
| 138 | 
            +
                    print("The file is most likely corrupted.", file=sys.stderr)
         | 
| 139 | 
            +
                    return None
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                except Exception:
         | 
| 142 | 
            +
                    print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
         | 
| 143 | 
            +
                    print(traceback.format_exc(), file=sys.stderr)
         | 
| 144 | 
            +
                    print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
         | 
| 145 | 
            +
                    print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
         | 
| 146 | 
            +
                    return None
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                return unsafe_torch_load(filename, *args, **kwargs)
         | 
| 149 | 
            +
             | 
| 150 | 
            +
             | 
| 151 | 
            +
            class Extra:
         | 
| 152 | 
            +
                """
         | 
| 153 | 
            +
                A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
         | 
| 154 | 
            +
                (because it's not your code making the torch.load call). The intended use is like this:
         | 
| 155 | 
            +
             | 
| 156 | 
            +
            ```
         | 
| 157 | 
            +
            import torch
         | 
| 158 | 
            +
            from modules import safe
         | 
| 159 | 
            +
             | 
| 160 | 
            +
            def handler(module, name):
         | 
| 161 | 
            +
                if module == 'torch' and name in ['float64', 'float16']:
         | 
| 162 | 
            +
                    return getattr(torch, name)
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                return None
         | 
| 165 | 
            +
             | 
| 166 | 
            +
            with safe.Extra(handler):
         | 
| 167 | 
            +
                x = torch.load('model.pt')
         | 
| 168 | 
            +
            ```
         | 
| 169 | 
            +
                """
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                def __init__(self, handler):
         | 
| 172 | 
            +
                    self.handler = handler
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                def __enter__(self):
         | 
| 175 | 
            +
                    global global_extra_handler
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    assert global_extra_handler is None, 'already inside an Extra() block'
         | 
| 178 | 
            +
                    global_extra_handler = self.handler
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                def __exit__(self, exc_type, exc_val, exc_tb):
         | 
| 181 | 
            +
                    global global_extra_handler
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    global_extra_handler = None
         | 
| 184 | 
            +
             | 
| 185 | 
            +
             | 
| 186 | 
            +
            unsafe_torch_load = torch.load
         | 
| 187 | 
            +
            torch.load = load
         | 
| 188 | 
            +
            global_extra_handler = None
         | 
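
A short usage sketch, assuming this module has been imported so that its torch.load replacement is active ('model.ckpt' is a placeholder path; the handler mirrors the example in the Extra docstring above):

import torch
from modules import safe  # importing the module installs safe.load as torch.load

# checkpoints are now verified by RestrictedUnpickler before being loaded for real;
# load() returns None instead of raising if verification fails
state = torch.load("model.ckpt")

# temporarily allow a few extra globals when the torch.load call is not in your own code
def handler(module, name):
    if module == "torch" and name in ["float64", "float16"]:
        return getattr(torch, name)
    return None

with safe.Extra(handler):
    state = torch.load("model.ckpt")
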
