use pruna for quantization

Files changed:
- server/config.py (+7, -0)
- server/pipelines/controlnet.py (+10, -1)
- server/pipelines/img2imgFlux.py (+101, -35)
- server/pipelines/img2imgSDTurbo.py (+10, -2)
- server/pipelines/img2imgSDXL-Lightning.py (+8, -0)
- server/pipelines/img2imgSDXLTurbo.py (+19, -3)
- server/requirements.txt (+18, -13)
server/config.py
CHANGED

@@ -20,6 +20,7 @@ class Args(BaseModel):
     onediff: bool = False
     compel: bool = False
     debug: bool = False
+    pruna: bool = False
 
     def pretty_print(self) -> None:
         print("\n")
@@ -123,6 +124,12 @@ parser.add_argument(
     default=False,
     help="Enable OneDiff",
 )
+parser.add_argument(
+    "--pruna",
+    action="store_true",
+    default=False,
+    help="Enable Pruna",
+)
 parser.set_defaults(taesd=USE_TAESD)
 
 config = Args.model_validate(vars(parser.parse_args()))
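The new flag follows the same argparse-to-pydantic wiring as the existing --onediff switch: argparse collects the boolean into a namespace, and Args.model_validate(vars(...)) turns that namespace into the typed config object the pipelines read. A minimal, self-contained sketch of the round trip (field and flag names are copied from the diff; the explicit argv list is only there to make the sketch runnable):

    import argparse

    from pydantic import BaseModel


    class Args(BaseModel):
        pruna: bool = False


    parser = argparse.ArgumentParser()
    parser.add_argument("--pruna", action="store_true", default=False, help="Enable Pruna")

    # Simulate `--pruna` on the command line; vars() turns the namespace into a dict.
    config = Args.model_validate(vars(parser.parse_args(["--pruna"])))
    assert config.pruna is True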
    	
server/pipelines/controlnet.py
CHANGED

@@ -17,6 +17,8 @@ from config import Args
 from pydantic import BaseModel, Field
 from PIL import Image
 import math
+from pruna import SmashConfig, smash
+from util import ParamsModel
 
 base_model = "SimianLuo/LCM_Dreamshaper_v7"
 taesd_model = "madebyollin/taesd"
@@ -58,7 +60,7 @@ class Pipeline:
         input_mode: str = "image"
         page_content: str = page_content
 
-    class InputParams(BaseModel):
+    class InputParams(ParamsModel):
         prompt: str = Field(
             default_prompt,
             title="Prompt",
@@ -170,6 +172,13 @@ class Pipeline:
                 taesd_model, torch_dtype=torch_dtype, use_safetensors=True
             ).to(device)
 
+        if args.pruna:
+            # Create and smash your model
+            smash_config = SmashConfig()
+            smash_config["cacher"] = "deepcache"
+            smash_config["compiler"] = "stable_fast"
+            self.pipe = smash(model=self.pipe, smash_config=smash_config)
+
         if args.sfast:
             print("\nRunning sfast compile\n")
             from sfast.compilers.stable_diffusion_pipeline_compiler import (
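The same guarded block recurs in the SD-Turbo, SDXL-Lightning, and SDXL-Turbo pipelines below: the already-loaded diffusers pipeline is wrapped once by pruna's smash, combining the deepcache cacher (which reuses UNet features across denoising steps) with the stable_fast compiler, and the returned object is then called like the original pipeline. A standalone sketch of that pattern, assuming a CUDA machine with pruna[stable-fast] installed (the model id comes from this file; the prompt and step count are illustrative):

    import torch
    from diffusers import DiffusionPipeline
    from pruna import SmashConfig, smash

    # Load the latent-consistency pipeline this file is built around.
    pipe = DiffusionPipeline.from_pretrained(
        "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
    ).to("cuda")

    smash_config = SmashConfig()
    smash_config["cacher"] = "deepcache"      # reuse UNet features across steps
    smash_config["compiler"] = "stable_fast"  # JIT-compile the UNet kernels
    pipe = smash(model=pipe, smash_config=smash_config)

    # The smashed object keeps the diffusers call signature.
    image = pipe(prompt="a sailboat at dusk", num_inference_steps=4).images[0]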
    	
server/pipelines/img2imgFlux.py
CHANGED

@@ -2,21 +2,19 @@ import torch
 
 from optimum.quanto import freeze, qfloat8, quantize
 from transformers.modeling_utils import PreTrainedModel
+from diffusers import AutoencoderTiny
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from diffusers.pipelines.flux.pipeline_flux_img2img import FluxImg2ImgPipeline
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
+
+
+from pruna import smash, SmashConfig
+from pruna.telemetry import set_telemetry_metrics
+
+set_telemetry_metrics(False)  # disable telemetry for current session
+set_telemetry_metrics(False, set_as_default=True)  # disable telemetry globally
 
-from diffusers import (
-    FlowMatchEulerDiscreteScheduler,
-    AutoencoderKL,
-    AutoencoderTiny,
-    FluxImg2ImgPipeline,
-    FluxPipeline,
-)
-
-from diffusers import (
-    FluxImg2ImgPipeline,
-    FluxPipeline,
-    FluxTransformer2DModel,
-    GGUFQuantizationConfig,
-)
 
 try:
     import intel_extension_for_pytorch as ipex  # type: ignore
@@ -76,10 +74,10 @@ class Pipeline:
             1, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
         )
         width: int = Field(
-
+            1024, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
         )
         height: int = Field(
-
+            1024, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
         )
         strength: float = Field(
             0.5,
@@ -107,33 +105,101 @@ class Pipeline:
         #     "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
         # )
         print("Loading model")
-
-
-
-
-
-
-
-
-
-
-
-
+
+        model_id = "black-forest-labs/FLUX.1-schnell"
+        model_revision = "refs/pr/1"
+        text_model_id = "openai/clip-vit-large-patch14"
+        model_data_type = torch.bfloat16
+        tokenizer = CLIPTokenizer.from_pretrained(
+            text_model_id, torch_dtype=model_data_type
+        )
+        text_encoder = CLIPTextModel.from_pretrained(
+            text_model_id, torch_dtype=model_data_type
+        )
+
+        # 2
+        tokenizer_2 = T5TokenizerFast.from_pretrained(
+            model_id,
+            subfolder="tokenizer_2",
+            torch_dtype=model_data_type,
+            revision=model_revision,
+        )
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id,
+            subfolder="text_encoder_2",
+            torch_dtype=model_data_type,
+            revision=model_revision,
+        )
+
+        # Transformers
+        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+            model_id, subfolder="scheduler", revision=model_revision
+        )
+        transformer = FluxTransformer2DModel.from_pretrained(
+            model_id,
+            subfolder="transformer",
+            torch_dtype=model_data_type,
+            revision=model_revision,
+        )
+
+        # VAE
+        # vae = AutoencoderKL.from_pretrained(
+        #     model_id,
+        #     subfolder="vae",
+        #     torch_dtype=model_data_type,
+        #     revision=model_revision,
+        # )
+
+        vae = AutoencoderTiny.from_pretrained(
+            "madebyollin/taef1", torch_dtype=torch.bfloat16
+        )
+
+        # Initialize the SmashConfig
+        smash_config = SmashConfig()
+        smash_config["quantizer"] = "quanto"
+        smash_config["quanto_calibrate"] = False
+        smash_config["quanto_weight_bits"] = "qint4"
+        # (
+        #     "qint4"  # "qfloat8"  # or "qint2", "qint4", "qint8"
+        # )
+
+        transformer = smash(
+            model=transformer,
+            smash_config=smash_config,
+        )
+        text_encoder_2 = smash(
+            model=text_encoder_2,
+            smash_config=smash_config,
+        )
+
+        pipe = FluxImg2ImgPipeline(
+            scheduler=scheduler,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            text_encoder_2=text_encoder_2,
+            tokenizer_2=tokenizer_2,
+            vae=vae,
             transformer=transformer,
-            torch_dtype=torch.bfloat16,
         )
-
-
-
-
+
+        # if args.taesd:
+        #     pipe.vae = AutoencoderTiny.from_pretrained(
+        #         taesd_path, torch_dtype=torch.bfloat16, use_safetensors=True
+        #     )
         # pipe.enable_model_cpu_offload()
-        pipe
+        pipe.text_encoder.to(device)
+        pipe.vae.to(device)
+        pipe.transformer.to(device)
+        pipe.text_encoder_2.to(device)
 
         # pipe.enable_model_cpu_offload()
+        # For added memory savings run this block, there is however a trade-off with speed.
+        # vae.enable_tiling()
+        # vae.enable_slicing()
+        # pipe.enable_sequential_cpu_offload()
 
         self.pipe = pipe
         self.pipe.set_progress_bar_config(disable=True)
-
         #     vae = AutoencoderKL.from_pretrained(
         #         base_model_path, subfolder="vae", torch_dtype=torch_dtype
         # )
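The rewrite drops the earlier quanto/GGUF loading code: each component is now loaded separately, the two largest ones (the Flux transformer and the T5-based text_encoder_2) are quantized to 4-bit weights through pruna's quanto quantizer, and the FluxImg2ImgPipeline is reassembled by hand before everything is moved to the device. Once built, it is invoked like any diffusers img2img pipeline; an illustrative call, assuming pipe is the pipeline assembled above (the file name and parameter values are made up for the example):

    import torch
    from PIL import Image

    # Hypothetical input frame; any RGB image at the configured resolution works.
    init_image = Image.open("input.png").convert("RGB").resize((1024, 1024))

    result = pipe(
        prompt="a watercolor landscape",
        image=init_image,
        strength=0.5,            # matches the InputParams default above
        num_inference_steps=4,   # schnell checkpoints are tuned for very few steps
        generator=torch.manual_seed(42),
    ).images[0]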
    	
server/pipelines/img2imgSDTurbo.py
CHANGED

@@ -15,6 +15,7 @@ from PIL import Image
 from util import ParamsModel
 import math
 
+from pruna import smash, SmashConfig
 
 base_model = "stabilityai/sd-turbo"
 taesd_model = "madebyollin/taesd"
@@ -102,6 +103,13 @@ class Pipeline:
                 taesd_model, torch_dtype=torch_dtype, use_safetensors=True
             ).to(device)
 
+        if args.pruna:
+            # Create and smash your model
+            smash_config = SmashConfig()
+            smash_config["cacher"] = "deepcache"
+            smash_config["compiler"] = "stable_fast"
+            self.pipe = smash(model=self.pipe, smash_config=smash_config)
+
         if args.sfast:
             from sfast.compilers.stable_diffusion_pipeline_compiler import (
                 compile,
@@ -130,8 +138,8 @@ class Pipeline:
 
         self.pipe.set_progress_bar_config(disable=True)
         self.pipe.to(device=device, dtype=torch_dtype)
-        if device.type != "mps":
-            self.pipe.unet.to(memory_format=torch.channels_last)
+        # if device.type != "mps":
+        #     self.pipe.unet.to(memory_format=torch.channels_last)
 
         if args.torch_compile:
             print("Running torch compile")
    	
server/pipelines/img2imgSDXL-Lightning.py
CHANGED

@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
 from PIL import Image
 from util import ParamsModel
 import math
+from pruna import SmashConfig, smash
 
 base = "stabilityai/stable-diffusion-xl-base-1.0"
 repo = "ByteDance/SDXL-Lightning"
@@ -135,6 +136,13 @@ class Pipeline:
             self.pipe.scheduler.config, timestep_spacing="trailing"
         )
 
+        if args.pruna:
+            # Create and smash your model
+            smash_config = SmashConfig()
+            smash_config["cacher"] = "deepcache"
+            smash_config["compiler"] = "stable_fast"
+            self.pipe = smash(model=self.pipe, smash_config=smash_config)
+
         if args.sfast:
             from sfast.compilers.stable_diffusion_pipeline_compiler import (
                 compile,
    	
server/pipelines/img2imgSDXLTurbo.py
CHANGED

@@ -17,6 +17,13 @@ from PIL import Image
 from util import ParamsModel
 import math
 
+from pruna import smash, SmashConfig
+from pruna.telemetry import set_telemetry_metrics
+
+set_telemetry_metrics(False)  # disable telemetry for current session
+set_telemetry_metrics(False, set_as_default=True)  # disable telemetry globally
+
+
 base_model = "stabilityai/sdxl-turbo"
 taesd_model = "madebyollin/taesdxl"
 
@@ -104,10 +111,11 @@ class Pipeline:
         )
 
     def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
-        self.pipe = AutoPipelineForImage2Image.from_pretrained(
+        base_pipe = AutoPipelineForImage2Image.from_pretrained(
             base_model,
             safety_checker=None,
         )
+        self.pipe = None
         if args.taesd:
             self.pipe.vae = AutoencoderTiny.from_pretrained(
                 taesd_model, torch_dtype=torch_dtype, use_safetensors=True
@@ -125,11 +133,16 @@ class Pipeline:
             config.enable_cuda_graph = True
             self.pipe = compile(self.pipe, config=config)
 
-        self.pipe.set_progress_bar_config(disable=True)
-        self.pipe.to(device=device, dtype=torch_dtype)
         if device.type != "mps":
             self.pipe.unet.to(memory_format=torch.channels_last)
 
+        if args.pruna:
+            # Create and smash your model
+            smash_config = SmashConfig()
+            smash_config["cacher"] = "deepcache"
+            smash_config["compiler"] = "stable_fast"
+            self.pipe = smash(model=base_pipe, smash_config=smash_config)
+
         if args.torch_compile:
             print("Running torch compile")
             self.pipe.unet = torch.compile(
@@ -151,6 +164,9 @@ class Pipeline:
             requires_pooled=[False, True],
         )
 
+        self.pipe.set_progress_bar_config(disable=True)
+        self.pipe.to(device=device, dtype=torch_dtype)
+
     def predict(self, params: "Pipeline.InputParams") -> Image.Image:
         generator = torch.manual_seed(params.seed)
         prompt = params.prompt
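Besides wiring in pruna, this file opts out of pruna telemetry at import time and moves the progress-bar and device setup to after the smash step, so the wrapped pipeline is what ends up on the device. One caveat for turbo-style img2img worth keeping in mind: diffusers requires int(num_inference_steps * strength) >= 1, so the default strength of 0.5 needs at least two steps. An illustrative predict-style call, assuming self.pipe is the smashed pipeline and image is the incoming PIL frame (names taken from context; values illustrative):

    result = self.pipe(
        prompt=params.prompt,
        image=image,             # hypothetical: the incoming frame as a PIL image
        strength=0.5,
        num_inference_steps=2,   # int(2 * 0.5) >= 1, the minimum diffusers accepts
        guidance_scale=0.0,      # sdxl-turbo is trained to run without CFG
        generator=generator,
    ).images[0]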
    	
server/requirements.txt
CHANGED

@@ -1,30 +1,35 @@
-
-
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch==2.5.1
+torchvision
+torchaudio
+xformers; sys_platform != 'darwin' or platform_machine != 'arm64'
+numpy
+diffusers
+llvmlite>=0.39.0
+numba>=0.56.0
+pruna[stable-fast] ; sys_platform != 'darwin' or platform_machine != 'arm64'
+transformers
+pydantic
 huggingface-hub
 hf_transfer
-
-
-fastapi==0.115.6
-uvicorn[standard]==0.34.0
+fastapi
+uvicorn[standard]
 Pillow==11.0.0
-accelerate
+accelerate
 compel==2.0.2
 controlnet-aux==0.0.9
 peft==0.14.0
-xformers; sys_platform != 'darwin' or platform_machine != 'arm64'
 markdown2
 safetensors
-stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/nightly/stable_fast-1.0.5.dev20241127+torch230cu121-cp310-cp310-manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
+# stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/nightly/stable_fast-1.0.5.dev20241127+torch230cu121-cp310-cp310-manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
 #oneflow @ https://github.com/siliconflow/oneflow_releases/releases/download/community_cu121/oneflow-0.9.1.dev20241114%2Bcu121-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
 #onediff @ git+https://github.com/siliconflow/onediff.git@main#egg=onediff ; sys_platform != 'darwin' or platform_machine != 'arm64'
 setuptools
 mpmath==1.3.0
-numpy==1.*
 controlnet-aux
 sentencepiece==0.2.0
-optimum-quanto
+optimum-quanto # has to be optimum-quanto==0.2.5 for pruna int4
 gguf==0.13.0
-pydantic>=2.7.0
 types-Pillow
 mypy
-python-dotenv
+python-dotenv
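The dependency list now pins torch 2.5.1 against the CUDA 11.8 wheel index and pulls in pruna with its stable-fast extra, skipping it on Apple Silicon via the environment markers. After installing from server/requirements.txt, a quick sanity check that the intended wheels were resolved (a sketch; the cu118 build only appears if pip actually picked the extra index over PyPI):

    import torch

    print(torch.__version__)          # expected to start with "2.5.1"
    print(torch.version.cuda)         # "11.8" if the cu118 wheels were picked
    print(torch.cuda.is_available())  # True on a working CUDA setup

    import pruna  # importable once pruna[stable-fast] is installed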