from cog import BasePredictor, Input, Path

import os
import time
import torch
import shutil
import subprocess
import numpy as np
from typing import List

from diffusers.utils import load_image
from transformers import CLIPImageProcessor
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    PNDMScheduler,
    KDPM2AncestralDiscreteScheduler,
    AutoencoderKL,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)

MODEL_NAME = "dataautogpt3/ProteusV0.4-Lightning"
MODEL_CACHE = "checkpoints"
SAFETY_CACHE = "safety-cache"
FEATURE_EXTRACTOR = "feature-extractor"
SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"
MODEL_URL = "https://weights.replicate.delivery/default/dataautogpt3/proteusv0.4-lightning.tar"
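

# KarrasDPM is not a scheduler in its own right: it is a small shim that exposes the
# same from_config() entry point as the diffusers scheduler classes below, so it can
# live in the SCHEDULERS mapping while delegating to DPMSolverMultistepScheduler with
# Karras sigmas enabled.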
class KarrasDPM:
    @staticmethod
    def from_config(config):
        return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True)


SCHEDULERS = {
    "DDIM": DDIMScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "HeunDiscrete": HeunDiscreteScheduler,
    "KarrasDPM": KarrasDPM,
    "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler,
    "K_EULER": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
    "DPM++2MSDE": KDPM2AncestralDiscreteScheduler,
}
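

# Weights are fetched at setup time rather than baked into the image. This assumes the
# pget binary (Replicate's parallel downloader) is installed in the container; the -x
# flag downloads the tarball and extracts it into `dest` in one step.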
def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)
class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        start = time.time()
        print("Loading safety checker...")
        if not os.path.exists(SAFETY_CACHE):
            download_weights(SAFETY_URL, SAFETY_CACHE)
        print("Loading model")
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)
        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            SAFETY_CACHE, torch_dtype=torch.float16
        ).to("cuda")
        self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
        print("Loading vae")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        print("Loading txt2img pipeline...")
        self.txt2img_pipe = StableDiffusionXLPipeline.from_pretrained(
            MODEL_NAME,
            vae=vae,
            torch_dtype=torch.float16,
            cache_dir=MODEL_CACHE,
        ).to("cuda")
        print("Loading SDXL img2img pipeline...")
        self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            text_encoder_2=self.txt2img_pipe.text_encoder_2,
            tokenizer=self.txt2img_pipe.tokenizer,
            tokenizer_2=self.txt2img_pipe.tokenizer_2,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
        ).to("cuda")
        print("Loading SDXL inpaint pipeline...")
        self.inpaint_pipe = StableDiffusionXLInpaintPipeline(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            text_encoder_2=self.txt2img_pipe.text_encoder_2,
            tokenizer=self.txt2img_pipe.tokenizer,
            tokenizer_2=self.txt2img_pipe.tokenizer_2,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
        ).to("cuda")
        print("setup took: ", time.time() - start)
    def load_image(self, path):
        shutil.copyfile(path, "/tmp/image.png")
        return load_image("/tmp/image.png").convert("RGB")

    def run_safety_checker(self, image):
        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
            "cuda"
        )
        np_image = [np.array(val) for val in image]
        image, has_nsfw_concept = self.safety_checker(
            images=np_image,
            clip_input=safety_checker_input.pixel_values.to(torch.float16),
        )
        return image, has_nsfw_concept
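
    # predict() picks a pipeline from its inputs: image + mask selects inpainting, image
    # alone selects img2img, and neither selects txt2img. Width and height are forwarded
    # only in the txt2img and inpaint branches; img2img follows the input image's size.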
    @torch.inference_mode()
    def predict(
        self,
        prompt: str = Input(
            description="Input prompt",
            default="3 fish in a fish tank wearing adorable outfits, best quality, hd",
        ),
        negative_prompt: str = Input(
            description="Negative input prompt",
            default="nsfw, bad quality, bad anatomy, worst quality, low quality, low resolutions, extra fingers, blur, blurry, ugly, wrongs proportions, watermark, image artifacts, lowres, ugly, jpeg artifacts, deformed, noisy image",
        ),
        image: Path = Input(
            description="Input image for img2img or inpaint mode",
            default=None,
        ),
        mask: Path = Input(
            description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.",
            default=None,
        ),
        width: int = Input(
            description="Width of output image. Recommended 1024 or 1280",
            default=1024,
        ),
        height: int = Input(
            description="Height of output image. Recommended 1024 or 1280",
            default=1024,
        ),
        num_outputs: int = Input(
            description="Number of images to output.",
            ge=1,
            le=4,
            default=1,
        ),
        scheduler: str = Input(
            description="Scheduler to use for denoising",
            choices=SCHEDULERS.keys(),
            default="K_EULER_ANCESTRAL",
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=100, default=8
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=0, le=10, default=2
        ),
        prompt_strength: float = Input(
            description="Prompt strength when using img2img / inpaint. 1.0 corresponds to full destruction of information in image",
            ge=0.0,
            le=1.0,
            default=0.8,
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
        apply_watermark: bool = Input(
            description="Applies a watermark to enable determining if an image is generated in downstream applications. If you have other provisions for generating or deploying images safely, you can use this to disable watermarking.",
            default=True,
        ),
        disable_safety_checker: bool = Input(
            description="Disable safety checker for generated images. This feature is only available through the API. See https://replicate.com/docs/how-does-replicate-work#safety",
            default=False,
        ),
    ) -> List[Path]:
        """Run a single prediction on the model"""
        if seed is None:
            seed = int.from_bytes(os.urandom(4), "big")
        print(f"Using seed: {seed}")
        generator = torch.Generator("cuda").manual_seed(seed)

        sdxl_kwargs = {}
        print(f"Prompt: {prompt}")
        if image and mask:
            print("inpainting mode")
            sdxl_kwargs["image"] = self.load_image(image)
            sdxl_kwargs["mask_image"] = self.load_image(mask)
            sdxl_kwargs["strength"] = prompt_strength
            sdxl_kwargs["width"] = width
            sdxl_kwargs["height"] = height
            pipe = self.inpaint_pipe
        elif image:
            print("img2img mode")
            sdxl_kwargs["image"] = self.load_image(image)
            sdxl_kwargs["strength"] = prompt_strength
            pipe = self.img2img_pipe
        else:
            print("txt2img mode")
            sdxl_kwargs["width"] = width
            sdxl_kwargs["height"] = height
            pipe = self.txt2img_pipe

        if not apply_watermark:
            # Temporarily detach the watermarker; it is restored after generation.
            watermark_cache = pipe.watermark
            pipe.watermark = None

        pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)

        common_args = {
            "prompt": [prompt] * num_outputs,
            "negative_prompt": [negative_prompt] * num_outputs,
            "guidance_scale": guidance_scale,
            "generator": generator,
            "num_inference_steps": num_inference_steps,
        }

        output = pipe(**common_args, **sdxl_kwargs)

        if not apply_watermark:
            pipe.watermark = watermark_cache

        if not disable_safety_checker:
            _, has_nsfw_content = self.run_safety_checker(output.images)

        output_paths = []
        for i, image in enumerate(output.images):
            if not disable_safety_checker:
                if has_nsfw_content[i]:
                    print(f"NSFW content detected in image {i}")
                    continue
            output_path = f"/tmp/out-{i}.png"
            image.save(output_path)
            output_paths.append(Path(output_path))

        if len(output_paths) == 0:
            raise Exception(
                "NSFW content detected. Try running it again, or try a different prompt."
            )

        return output_paths
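

# The block below is a minimal local smoke test, not part of the Cog interface. It is a
# sketch that assumes a CUDA GPU, the pget binary on PATH, and network access to the
# weight URLs above; on Replicate, Cog calls setup() and predict() itself and this block
# is never executed.
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    paths = predictor.predict(
        prompt="a watercolor painting of a lighthouse at dawn",
        negative_prompt="nsfw, bad quality, blurry",
        image=None,
        mask=None,
        width=1024,
        height=1024,
        num_outputs=1,
        scheduler="K_EULER_ANCESTRAL",
        num_inference_steps=8,
        guidance_scale=2.0,
        prompt_strength=0.8,
        seed=None,
        apply_watermark=True,
        disable_safety_checker=False,
    )
    print("wrote:", [str(p) for p in paths])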