File size: 8,181 Bytes
9d51df0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
from cog import BasePredictor, Input, Path
import os
from subprocess import call
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler
from PIL import Image
import numpy as np
from typing import List
from utils import get_state_dict_path, download_model, model_dl_urls, annotator_dl_urls
# Selects which conditioning pipeline this predictor is built for. Only the
# matching ``process_*`` function is imported, so the other gradio modules are
# never loaded for this build.
MODEL_TYPE = "openpose"

if MODEL_TYPE == "canny":
    from gradio_canny2image import process_canny
elif MODEL_TYPE == "depth":
    from gradio_depth2image import process_depth
elif MODEL_TYPE == "hed":
    from gradio_hed2image import process_hed
elif MODEL_TYPE == "normal":
    from gradio_normal2image import process_normal
elif MODEL_TYPE == "mlsd":
    from gradio_hough2image import process_mlsd
elif MODEL_TYPE == "scribble":
    from gradio_scribble2image import process_scribble
elif MODEL_TYPE == "seg":
    from gradio_seg2image import process_seg
elif MODEL_TYPE == "openpose":
    from gradio_pose2image import process_pose
else:
    # Fail at import time instead of much later inside predict(), where the
    # missing process_* name would otherwise surface as an obscure error.
    raise ValueError(f"Unsupported MODEL_TYPE: {MODEL_TYPE!r}")
class Predictor(BasePredictor):
    """Cog predictor that runs the ``process_*`` pipeline selected by ``MODEL_TYPE``."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.model = create_model('./models/cldm_v15.yaml').cuda()
        self.model.load_state_dict(load_state_dict(get_state_dict_path(MODEL_TYPE), location='cuda'))
        self.ddim_sampler = DDIMSampler(self.model)

    def predict(
        self,
        image: Path = Input(description="Input image"),
        prompt: str = Input(description="Prompt for the model"),
        num_samples: str = Input(
            description="Number of samples (higher values may OOM)",
            choices=['1', '4'],
            default='1'
        ),
        image_resolution: str = Input(
            description="Image resolution to be generated",
            choices = ['256', '512', '768'],
            default='512'
        ),
        low_threshold: int = Input(description="Canny line detection low threshold", default=100, ge=1, le=255), # only applicable when model type is 'canny'
        high_threshold: int = Input(description="Canny line detection high threshold", default=200, ge=1, le=255), # only applicable when model type is 'canny'
        ddim_steps: int = Input(description="Steps", default=20),
        scale: float = Input(description="Scale for classifier-free guidance", default=9.0, ge=0.1, le=30.0),
        seed: int = Input(description="Seed", default=None),
        eta: float = Input(description="Controls the amount of noise that is added to the input data during the denoising diffusion process. Higher value -> more noise", default=0.0),
        a_prompt: str = Input(description="Additional text to be appended to prompt", default="best quality, extremely detailed"),
        n_prompt: str = Input(description="Negative Prompt", default="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"),
        detect_resolution: int = Input(description="Resolution at which detection method will be applied)", default=512, ge=128, le=1024), # only applicable when model type is 'HED', 'seg', or 'MLSD'
        # bg_threshold: float = Input(description="Background Threshold (only applicable when model type is 'normal')", default=0.0, ge=0.0, le=1.0), # only applicable when model type is 'normal'
        # value_threshold: float = Input(description="Value Threshold (only applicable when model type is 'MLSD')", default=0.1, ge=0.01, le=2.0), # only applicable when model type is 'MLSD'
        # distance_threshold: float = Input(description="Distance Threshold (only applicable when model type is 'MLSD')", default=0.1, ge=0.01, le=20.0), # only applicable when model type is 'MLSD'
    ) -> List[Path]:
        """Run a single prediction on the model.

        The input image is loaded into a numpy array and handed, together
        with the prompt and sampling settings, to the ``process_*`` function
        that matches ``MODEL_TYPE``. Each returned array is written to
        ``tmp/output_<i>.png``.

        Returns:
            One ``Path`` per saved output image, in the order produced.

        Raises:
            ValueError: If ``MODEL_TYPE`` is not one of the supported types.
        """
        # Cog delivers the `choices` inputs as strings; the pipelines get ints.
        num_samples = int(num_samples)
        image_resolution = int(image_resolution)

        # Compare against None so that an explicit seed of 0 is honoured
        # instead of being silently replaced by a random one.
        if seed is None:
            seed = np.random.randint(1000000)
        else:
            seed = int(seed)

        # Convert inside the `with` block so the file handle is released as
        # soon as the pixel data has been copied into the array.
        with Image.open(image) as pil_image:
            input_image = np.array(pil_image)

        # These knobs are not exposed as Cog inputs (their Input declarations
        # are commented out in the signature), so use the defaults listed
        # there; otherwise the 'normal' and 'mlsd' branches hit a NameError.
        bg_threshold = 0.0
        value_threshold = 0.1
        distance_threshold = 0.1

        # Positional argument groups shared by every process_* call.
        leading = (input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution)
        sampling = (ddim_steps, scale, seed, eta)
        runtime = (self.model, self.ddim_sampler)

        if MODEL_TYPE == "canny":
            outputs = process_canny(*leading, *sampling, low_threshold, high_threshold, *runtime)
        elif MODEL_TYPE == "depth":
            outputs = process_depth(*leading, detect_resolution, *sampling, *runtime)
        elif MODEL_TYPE == "hed":
            outputs = process_hed(*leading, detect_resolution, *sampling, *runtime)
        elif MODEL_TYPE == "normal":
            outputs = process_normal(*leading, *sampling, bg_threshold, *runtime)
        elif MODEL_TYPE == "mlsd":
            outputs = process_mlsd(
                *leading,
                detect_resolution,
                *sampling,
                value_threshold,
                distance_threshold,
                *runtime,
            )
        elif MODEL_TYPE == "scribble":
            outputs = process_scribble(*leading, *sampling, *runtime)
        elif MODEL_TYPE == "seg":
            outputs = process_seg(*leading, detect_resolution, *sampling, *runtime)
        elif MODEL_TYPE == "openpose":
            outputs = process_pose(*leading, detect_resolution, *sampling, *runtime)
        else:
            raise ValueError(f"Unsupported MODEL_TYPE: {MODEL_TYPE!r}")

        # Image.save does not create missing directories.
        os.makedirs("tmp", exist_ok=True)

        # Convert each output array to a PIL image, save it, and collect its path.
        output_paths = []
        for index, output in enumerate(outputs):
            output_file = f"tmp/output_{index}.png"
            Image.fromarray(output).save(output_file)
            output_paths.append(Path(output_file))
        return output_paths
|