# This file is adapted from gradio_*.py in https://github.com/lllyasviel/ControlNet/tree/f4748e3630d8141d7765e2bd9b1e348f47847707
# The original license file is LICENSE.ControlNet in this repo.
from __future__ import annotations

import gc
import pathlib
import sys

import cv2
import numpy as np
import PIL.Image
import torch
from diffusers import (ControlNetModel, DiffusionPipeline,
                       StableDiffusionControlNetPipeline,
                       UniPCMultistepScheduler)

repo_dir = pathlib.Path(__file__).parent
submodule_dir = repo_dir / 'ControlNet'
sys.path.append(submodule_dir.as_posix())

from annotator.midas import apply_midas
from annotator.util import HWC3, resize_image

CONTROLNET_MODEL_IDS = {
    'depth': 'lllyasviel/sd-controlnet-depth',
}


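# Only the depth task is used in this adaptation; the mapping above could be
# extended with other lllyasviel/sd-controlnet-* checkpoints. Pre-fetching
# every listed checkpoint warms the local Hugging Face cache so the first
# request does not block on a download.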
def download_all_controlnet_weights() -> None:
    for model_id in CONTROLNET_MODEL_IDS.values():
        ControlNetModel.from_pretrained(model_id)


class Model:
    def __init__(self,
                 base_model_id: str = 'runwayml/stable-diffusion-v1-5',
                 task_name: str = 'depth'):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.base_model_id = ''
        self.task_name = ''
        self.pipe = self.load_pipe(base_model_id, task_name)

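    # Rebuild the pipeline only when the base model or the task changes;
    # otherwise the cached pipeline is returned as-is.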
    def load_pipe(self, base_model_id: str,
                  task_name: str) -> DiffusionPipeline:
        if (base_model_id == self.base_model_id
                and task_name == self.task_name and hasattr(self, 'pipe')):
            return self.pipe
        model_id = CONTROLNET_MODEL_IDS[task_name]
        controlnet = ControlNetModel.from_pretrained(model_id,
                                                     torch_dtype=torch.float16)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_id,
            safety_checker=None,
            controlnet=controlnet,
            torch_dtype=torch.float16)
        pipe.scheduler = UniPCMultistepScheduler.from_config(
            pipe.scheduler.config)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        self.base_model_id = base_model_id
        self.task_name = task_name
        return pipe

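    # Switch the base Stable Diffusion checkpoint; if the new checkpoint
    # fails to load, fall back to reloading the previous one.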
    def set_base_model(self, base_model_id: str) -> str:
        if not base_model_id or base_model_id == self.base_model_id:
            return self.base_model_id
        del self.pipe
        torch.cuda.empty_cache()
        gc.collect()
        try:
            self.pipe = self.load_pipe(base_model_id, self.task_name)
        except Exception:
            self.pipe = self.load_pipe(self.base_model_id, self.task_name)
        return self.base_model_id

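    # Hot-swap only the ControlNet weights and keep the (much larger) base
    # pipeline in memory.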
    def load_controlnet_weight(self, task_name: str) -> None:
        if task_name == self.task_name:
            return
        del self.pipe.controlnet
        torch.cuda.empty_cache()
        gc.collect()
        model_id = CONTROLNET_MODEL_IDS[task_name]
        controlnet = ControlNetModel.from_pretrained(model_id,
                                                     torch_dtype=torch.float16)
        controlnet.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        self.pipe.controlnet = controlnet
        self.task_name = task_name

    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
        if not prompt:
            prompt = additional_prompt
        else:
            prompt = f'{prompt}, {additional_prompt}'
        return prompt

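    # A seed of -1 requests a fresh random seed; sampling runs under CUDA
    # autocast to match the fp16 weights.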
    @torch.autocast('cuda')
    def run_pipe(
        self,
        prompt: str,
        negative_prompt: str,
        control_image: PIL.Image.Image,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        if seed == -1:
            seed = np.random.randint(0, np.iinfo(np.int64).max)
        generator = torch.Generator().manual_seed(seed)
        return self.pipe(prompt=prompt,
                         negative_prompt=negative_prompt,
                         guidance_scale=guidance_scale,
                         num_images_per_prompt=num_images,
                         num_inference_steps=num_steps,
                         generator=generator,
                         image=control_image).images

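    # Either estimate depth with MiDaS from an RGB input or treat the input
    # as an already-computed depth map. The control image is returned twice:
    # once as the pipeline input and once for visualization.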
    @staticmethod
    def preprocess_depth(
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
        is_depth_image: bool,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        input_image = HWC3(input_image)
        if not is_depth_image:
            control_image, _ = apply_midas(
                resize_image(input_image, detect_resolution))
            control_image = HWC3(control_image)
            image = resize_image(input_image, image_resolution)
            H, W = image.shape[:2]
            control_image = cv2.resize(control_image, (W, H),
                                       interpolation=cv2.INTER_LINEAR)
        else:
            control_image = resize_image(input_image, image_resolution)
        return (PIL.Image.fromarray(control_image),
                PIL.Image.fromarray(control_image))

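    # Full depth-to-image pass: preprocess the input, make sure the depth
    # ControlNet is loaded, then run the diffusion pipeline.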
    @torch.inference_mode()
    def process_depth(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        is_depth_image: bool,
    ) -> list[PIL.Image.Image]:
        control_image, vis_control_image = self.preprocess_depth(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
            is_depth_image=is_depth_image,
        )
        self.load_controlnet_weight('depth')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [vis_control_image] + results
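

# A minimal usage sketch, not part of the original Space: it assumes a CUDA
# GPU with xformers installed and an RGB image saved at 'input.png'; the
# prompts, resolutions, and file names below are illustrative only.
if __name__ == '__main__':
    model = Model(base_model_id='runwayml/stable-diffusion-v1-5',
                  task_name='depth')
    input_image = np.array(PIL.Image.open('input.png').convert('RGB'))
    images = model.process_depth(
        input_image=input_image,
        prompt='a cozy living room',
        additional_prompt='best quality, extremely detailed',
        negative_prompt='lowres, bad anatomy, worst quality',
        num_images=1,
        image_resolution=512,
        detect_resolution=384,
        num_steps=20,
        guidance_scale=9.0,
        seed=-1,
        is_depth_image=False,
    )
    # images[0] is the visualized depth map; the rest are generated samples.
    for i, image in enumerate(images):
        image.save(f'out_{i}.png')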