# IF / model.py
# Last commit by multimodalart: "Remove compile for stage2 and 3 (#139)"
# (commit f0738d3)
from __future__ import annotations
import gc
import json
import tempfile
from typing import Generator
import numpy as np
import PIL.Image
import torch
from diffusers import DiffusionPipeline, StableDiffusionUpscalePipeline
from diffusers.pipelines.deepfloyd_if import (fast27_timesteps,
smart27_timesteps,
smart50_timesteps,
smart100_timesteps,
smart185_timesteps)
from settings import (DISABLE_AUTOMATIC_CPU_OFFLOAD, DISABLE_SD_X4_UPSCALER,
HF_TOKEN, MAX_NUM_IMAGES, MAX_NUM_STEPS, MAX_SEED,
RUN_GARBAGE_COLLECTION)
class Model:
    """Wrapper around the three DeepFloyd IF generation stages.

    * Stage 1 (``DeepFloyd/IF-I-XL-v1.0``): text -> images (``run_stage1``).
    * Stage 2 (``DeepFloyd/IF-II-L-v1.0``): upscale one stage-1 image
      (``run_stage2``).
    * Stage 3 (``stabilityai/stable-diffusion-x4-upscaler``, optional):
      upscale the stage-2 image (``run_stage3``).

    Weights are loaded only when CUDA is available; on a CPU-only host every
    pipeline attribute (and ``watermark_image``) stays ``None``.
    """

    def __init__(self):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.pipe = None
        self.super_res_1_pipe = None
        self.super_res_2_pipe = None
        # RGBA PIL copy of the stage-1 pipeline's watermark; used by
        # apply_watermark_to_sd_x4_upscaler_results to mark stage-3 output.
        self.watermark_image = None
        if torch.cuda.is_available():
            self.load_weights()
            self.watermark_image = PIL.Image.fromarray(
                self.pipe.watermarker.watermark_image.to(
                    torch.uint8).cpu().numpy(),
                mode='RGBA')

    def load_weights(self) -> None:
        """Load all pipelines and place them on the target device.

        With ``DISABLE_AUTOMATIC_CPU_OFFLOAD`` set, the pipelines are moved to
        ``self.device`` directly and the stage-1 UNet is compiled; otherwise
        diffusers' model CPU offload is enabled for every loaded pipeline.
        The stage-3 upscaler is skipped when ``DISABLE_SD_X4_UPSCALER`` is set.
        """
        # NOTE(review): newer diffusers releases deprecate `use_auth_token`
        # in favour of `token`; confirm against the pinned diffusers version.
        self.pipe = DiffusionPipeline.from_pretrained(
            'DeepFloyd/IF-I-XL-v1.0',
            torch_dtype=torch.float16,
            variant='fp16',
            use_safetensors=True,
            use_auth_token=HF_TOKEN)
        # Stage 2 is fed the stage-1 prompt embeddings (see run_stage2), so
        # its text encoder is not loaded.
        self.super_res_1_pipe = DiffusionPipeline.from_pretrained(
            'DeepFloyd/IF-II-L-v1.0',
            text_encoder=None,
            torch_dtype=torch.float16,
            variant='fp16',
            use_safetensors=True,
            use_auth_token=HF_TOKEN)
        if not DISABLE_SD_X4_UPSCALER:
            self.super_res_2_pipe = StableDiffusionUpscalePipeline.from_pretrained(
                'stabilityai/stable-diffusion-x4-upscaler',
                torch_dtype=torch.float16)

        if DISABLE_AUTOMATIC_CPU_OFFLOAD:
            self.pipe.to(self.device)
            self.super_res_1_pipe.to(self.device)
            # Only the stage-1 UNet is compiled; stages 2 and 3 run eagerly.
            self.pipe.unet.to(memory_format=torch.channels_last)
            self.pipe.unet = torch.compile(self.pipe.unet,
                                           mode='reduce-overhead',
                                           fullgraph=True)
            if not DISABLE_SD_X4_UPSCALER:
                self.super_res_2_pipe.to(self.device)
        else:
            self.pipe.enable_model_cpu_offload()
            self.super_res_1_pipe.enable_model_cpu_offload()
            if not DISABLE_SD_X4_UPSCALER:
                self.super_res_2_pipe.enable_model_cpu_offload()

    def apply_watermark_to_sd_x4_upscaler_results(
            self, images: list[PIL.Image.Image]) -> None:
        """Paste the IF watermark, in place, onto stage-3 upscaler images.

        The size/position formula presumably mirrors diffusers' IF
        watermarker, with the x4 upscaler's sample size (128) in place of the
        UNet sample size. Every image is assumed to have the size of the
        first one. An empty list is a no-op.
        """
        if not images:
            return
        w, h = images[0].size
        stability_x4_upscaler_sample_size = 128
        coef = min(h / stability_x4_upscaler_sample_size,
                   w / stability_x4_upscaler_sample_size)
        img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w)

        # Scale the watermark relative to a 1024x1024 reference area.
        S1, S2 = 1024**2, img_w * img_h
        K = (S2 / S1)**0.5
        watermark_size = int(K * 62)
        watermark_x = img_w - int(14 * K)
        watermark_y = img_h - int(14 * K)

        # resize() already returns a new image, so no copy() is needed.
        watermark_image = self.watermark_image.resize(
            (watermark_size, watermark_size),
            PIL.Image.Resampling.BICUBIC,
            reducing_gap=None)
        # The alpha channel is the paste mask; it is the same for every image.
        mask = watermark_image.split()[-1]
        box = (
            watermark_x - watermark_size,
            watermark_y - watermark_size,
            watermark_x,
            watermark_y,
        )
        for image in images:
            image.paste(watermark_image, box=box, mask=mask)

    @staticmethod
    def to_pil_images(images: torch.Tensor) -> list[PIL.Image.Image]:
        """Convert a channels-first image batch in [-1, 1] to PIL images.

        ``images`` is permuted as (N, C, H, W) -> (N, H, W, C); values outside
        [-1, 1] are clamped before quantizing to uint8.
        """
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        images = np.round(images * 255).astype(np.uint8)
        return [PIL.Image.fromarray(image) for image in images]

    @staticmethod
    def check_seed(seed: int) -> None:
        """Raise ``ValueError`` unless ``0 <= seed <= MAX_SEED``."""
        if not 0 <= seed <= MAX_SEED:
            raise ValueError(f'seed must be in [0, {MAX_SEED}], got {seed}')

    @staticmethod
    def check_num_images(num_images: int) -> None:
        """Raise ``ValueError`` unless ``1 <= num_images <= MAX_NUM_IMAGES``."""
        if not 1 <= num_images <= MAX_NUM_IMAGES:
            raise ValueError(f'num_images must be in [1, {MAX_NUM_IMAGES}], '
                             f'got {num_images}')

    @staticmethod
    def check_num_inference_steps(num_steps: int) -> None:
        """Raise ``ValueError`` unless ``1 <= num_steps <= MAX_NUM_STEPS``."""
        if not 1 <= num_steps <= MAX_NUM_STEPS:
            raise ValueError(f'num_inference_steps must be in '
                             f'[1, {MAX_NUM_STEPS}], got {num_steps}')

    @staticmethod
    def get_custom_timesteps(name: str) -> list[int] | None:
        """Return the diffusers timestep preset called ``name``.

        ``'none'`` maps to ``None`` (no custom schedule). Any other unknown
        name raises ``ValueError``.
        """
        presets = {
            'none': None,
            'fast27': fast27_timesteps,
            'smart27': smart27_timesteps,
            'smart50': smart50_timesteps,
            'smart100': smart100_timesteps,
            'smart185': smart185_timesteps,
        }
        try:
            return presets[name]
        except KeyError:
            raise ValueError(
                f'unknown custom timesteps {name!r}; expected one of '
                f'{sorted(presets)}') from None

    @staticmethod
    def run_garbage_collection():
        """Run Python GC and release cached CUDA memory."""
        gc.collect()
        torch.cuda.empty_cache()

    def _require_sd_x4_upscaler(self) -> None:
        """Raise ``RuntimeError`` if the stage-3 upscaler was not loaded."""
        if self.super_res_2_pipe is None:
            raise RuntimeError(
                'Stage 3 (SD x4 upscaler) is unavailable: it is disabled via '
                'DISABLE_SD_X4_UPSCALER or no CUDA device was found.')

    def run_stage1(
        self,
        prompt: str,
        negative_prompt: str = '',
        seed: int = 0,
        num_images: int = 1,
        guidance_scale_1: float = 7.0,
        custom_timesteps_1: str = 'smart100',
        num_inference_steps_1: int = 100,
    ) -> tuple[list[PIL.Image.Image], str, str]:
        """Generate ``num_images`` stage-1 images for ``prompt``.

        Returns:
            A tuple of (watermarked PIL images, path of a JSON file with the
            generation parameters, path of a ``torch.save`` file holding the
            prompt embeddings and raw image tensor consumed by ``run_stage2``).

        Raises:
            ValueError: on an out-of-range seed, image count or step count,
                or an unknown timestep preset name.
        """
        self.check_seed(seed)
        self.check_num_images(num_images)
        self.check_num_inference_steps(num_inference_steps_1)
        # Resolve the preset before any GPU work so a bad name fails fast.
        timesteps = self.get_custom_timesteps(custom_timesteps_1)

        if RUN_GARBAGE_COLLECTION:
            self.run_garbage_collection()

        generator = torch.Generator(device=self.device).manual_seed(seed)
        prompt_embeds, negative_embeds = self.pipe.encode_prompt(
            prompt=prompt, negative_prompt=negative_prompt)
        images = self.pipe(prompt_embeds=prompt_embeds,
                           negative_prompt_embeds=negative_embeds,
                           num_images_per_prompt=num_images,
                           guidance_scale=guidance_scale_1,
                           timesteps=timesteps,
                           num_inference_steps=num_inference_steps_1,
                           generator=generator,
                           output_type='pt').images

        pil_images = self.to_pil_images(images)
        self.pipe.watermarker.apply_watermark(
            pil_images, self.pipe.unet.config.sample_size)

        stage1_params = {
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'seed': seed,
            'num_images': num_images,
            'guidance_scale_1': guidance_scale_1,
            'custom_timesteps_1': custom_timesteps_1,
            'num_inference_steps_1': num_inference_steps_1,
        }
        # delete=False: the paths are returned, so the files must outlive
        # this call. NOTE(review): nothing in this class removes them.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json',
                                         delete=False) as param_file:
            json.dump(stage1_params, param_file)

        stage1_result = {
            'prompt_embeds': prompt_embeds,
            'negative_embeds': negative_embeds,
            'images': images,
            'pil_images': pil_images,
        }
        with tempfile.NamedTemporaryFile(suffix='.pt',
                                         delete=False) as result_file:
            # Save through the open handle: re-opening a still-open
            # NamedTemporaryFile by name fails on Windows.
            torch.save(stage1_result, result_file)
        return pil_images, param_file.name, result_file.name

    def run_stage2(
        self,
        stage1_result_path: str,
        stage2_index: int,
        seed_2: int = 0,
        guidance_scale_2: float = 4.0,
        custom_timesteps_2: str = 'smart50',
        num_inference_steps_2: int = 50,
        disable_watermark: bool = False,
    ) -> PIL.Image.Image:
        """Upscale one stage-1 image with the stage-2 pipeline.

        Args:
            stage1_result_path: result path returned by ``run_stage1``.
            stage2_index: index of the stage-1 image to upscale.
            disable_watermark: return the image unmarked (``run_stage2_3``
                uses this so stage 3 receives a clean input).

        Raises:
            ValueError: on an out-of-range seed or step count, or an unknown
                timestep preset name.
        """
        self.check_seed(seed_2)
        self.check_num_inference_steps(num_inference_steps_2)
        timesteps = self.get_custom_timesteps(custom_timesteps_2)

        if RUN_GARBAGE_COLLECTION:
            self.run_garbage_collection()

        generator = torch.Generator(device=self.device).manual_seed(seed_2)
        # NOTE(security): torch.load unpickles arbitrary objects. This is only
        # safe when stage1_result_path is a file written by run_stage1 on this
        # server, never a client-supplied path or upload. weights_only=False is
        # explicit because the payload also holds PIL images, which the
        # weights-only loader (the default in newer PyTorch) rejects.
        stage1_result = torch.load(stage1_result_path, weights_only=False)
        # Indexing with a list keeps a leading batch dimension of 1.
        images = stage1_result['images'][[stage2_index]]

        out = self.super_res_1_pipe(
            image=images,
            prompt_embeds=stage1_result['prompt_embeds'],
            negative_prompt_embeds=stage1_result['negative_embeds'],
            num_images_per_prompt=1,
            guidance_scale=guidance_scale_2,
            timesteps=timesteps,
            num_inference_steps=num_inference_steps_2,
            generator=generator,
            output_type='pt',
            noise_level=250).images

        pil_images = self.to_pil_images(out)
        if not disable_watermark:
            self.super_res_1_pipe.watermarker.apply_watermark(
                pil_images, self.super_res_1_pipe.unet.config.sample_size)
        return pil_images[0]

    def run_stage3(
        self,
        image: PIL.Image.Image,
        prompt: str = '',
        negative_prompt: str = '',
        seed_3: int = 0,
        guidance_scale_3: float = 9.0,
        num_inference_steps_3: int = 75,
    ) -> PIL.Image.Image:
        """Upscale ``image`` with the SD x4 upscaler and watermark the result.

        Raises:
            ValueError: on an out-of-range seed or step count.
            RuntimeError: if the stage-3 upscaler is not loaded.
        """
        self.check_seed(seed_3)
        self.check_num_inference_steps(num_inference_steps_3)
        self._require_sd_x4_upscaler()

        if RUN_GARBAGE_COLLECTION:
            self.run_garbage_collection()

        generator = torch.Generator(device=self.device).manual_seed(seed_3)
        out = self.super_res_2_pipe(image=image,
                                    prompt=prompt,
                                    negative_prompt=negative_prompt,
                                    num_images_per_prompt=1,
                                    guidance_scale=guidance_scale_3,
                                    num_inference_steps=num_inference_steps_3,
                                    generator=generator,
                                    noise_level=100).images
        self.apply_watermark_to_sd_x4_upscaler_results(out)
        return out[0]

    def run_stage2_3(
        self,
        stage1_result_path: str,
        stage2_index: int,
        seed_2: int = 0,
        guidance_scale_2: float = 4.0,
        custom_timesteps_2: str = 'smart50',
        num_inference_steps_2: int = 50,
        prompt: str = '',
        negative_prompt: str = '',
        seed_3: int = 0,
        guidance_scale_3: float = 9.0,
        num_inference_steps_3: int = 75,
    ) -> Generator[PIL.Image.Image, None, None]:
        """Run stage 2 then stage 3, yielding both results.

        Yields:
            First the watermarked stage-2 image (for preview), then the
            watermarked stage-3 image.
        """
        # Validate stage-3 inputs before the expensive stage-2 pass.
        self.check_seed(seed_3)
        self.check_num_inference_steps(num_inference_steps_3)
        self._require_sd_x4_upscaler()

        out_image = self.run_stage2(
            stage1_result_path=stage1_result_path,
            stage2_index=stage2_index,
            seed_2=seed_2,
            guidance_scale_2=guidance_scale_2,
            custom_timesteps_2=custom_timesteps_2,
            num_inference_steps_2=num_inference_steps_2,
            disable_watermark=True)
        # Watermark a copy for display; stage 3 gets the unmarked image.
        preview = out_image.copy()
        self.super_res_1_pipe.watermarker.apply_watermark(
            [preview], self.super_res_1_pipe.unet.config.sample_size)
        yield preview
        yield self.run_stage3(image=out_image,
                              prompt=prompt,
                              negative_prompt=negative_prompt,
                              seed_3=seed_3,
                              guidance_scale_3=guidance_scale_3,
                              num_inference_steps_3=num_inference_steps_3)