# PIA / inference.py
# Source: LeoXing1996, commit a001281 ("init repo for fg"), 4.38 kB.
# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
import argparse
import os
import imageio
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from animatediff.pipelines import I2VPipeline
# Sampling defaults; a --config file may override both via its
# 'guidance_scale' and 'n_prompt' keys.
GUIDANCE_SCALE = 7
N_PROMPT = 'worst quality,low quality'

# Config and checkpoint locations, relative to the current working directory.
# Only DREAMBOOTH_PATH can be overridden from a --config file ('dreambooth').
BASE_CFG = './example/config/base.yaml'
BASE_MODEL = './models/StableDiffusion/sd15'
I2V_MODEL = './models/PIA/pia.ckpt'
DREAMBOOTH_PATH = './models/DreamBooth_LoRA/Counterfeit-V3.0_fp32.safetensors'
def post_process(videos: torch.Tensor):
    """Turn the first video of a pipeline output into uint8 frames.

    Args:
        videos: 5-D tensor whose ``videos[0]`` is laid out
            (channels, frames, height, width); only that first entry of the
            leading (presumably batch) dimension is used. Values are presumably
            in [0, 1]: they are scaled by 255 and clipped to [0, 255].

    Returns:
        A numpy ``uint8`` array of shape (frames, height, width, channels).
        Fractional values are truncated, not rounded, by the uint8 cast.
    """
    first = videos[0]
    # (c, t, h, w) -> (t, h, w, c): channels-last frames for the video writer.
    frames = first.permute(1, 2, 3, 0)
    scaled = (frames * 255).clip(0, 255)
    return scaled.cpu().numpy().astype(np.uint8)
def seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs.

    Args:
        seed: integer seed. The legacy ``np.random.seed`` only accepts values
            in [0, 2**32), so NumPy receives ``seed % 2**32``; the Python and
            PyTorch generators receive ``seed`` unchanged.
    """
    # ``random`` is not imported at module level; ``np`` and ``torch`` are, so
    # they are not re-imported here.
    import random

    random.seed(seed)
    np.random.seed(seed % (2**32))
    torch.manual_seed(seed)
    # Per the PyTorch docs this is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
def preprocess_img(img_path):
ori_image = Image.open(img_path).convert('RGB')
width, height = ori_image.size
long_edge = max(width, height)
if long_edge > 512:
scale_factor = 512 / long_edge
else:
scale_factor = 1
width = int(width * scale_factor)
height = int(height * scale_factor)
ori_image = ori_image.resize((width, height))
if (width % 8 != 0) or (height % 8 != 0):
in_width = (width // 8) * 8
in_height = (height // 8) * 8
else:
in_width = width
in_height = height
in_image = ori_image
in_image = ori_image.resize((in_width, in_height))
in_image_np = np.array(in_image)
return in_image_np, in_height, in_width
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # --img, --prompt and --save-name have no usable default. Without them the
    # script would crash only after the models are loaded (and for --save-name
    # only after sampling has finished), so argparse rejects the call up front.
    parser.add_argument('--img', type=str, required=True,
                        help='input image to animate')
    parser.add_argument("--config", type=str,
                        help='optional config overriding DreamBooth/LoRA/VAE '
                             'paths and sampling defaults')
    parser.add_argument('--prompt', type=str, required=True)
    parser.add_argument('--save-name', type=str, required=True,
                        help='output video path; parent dirs are created')
    parser.add_argument('--motion', type=int, default=2,
                        help='motion template index (mask_sim_template_idx)')
    parser.add_argument('--ip-scale', type=float, default=0.3,
                        help='IP-Adapter scale')
    parser.add_argument('--strength', type=float, default=1)
    parser.add_argument('--seed', type=int, default=None,
                        help='seed all RNGs for reproducible output; '
                             'no seeding when omitted')
    args = parser.parse_args()

    # prepare paths and pipeline
    if args.config:
        config = OmegaConf.load(args.config)
        print('Load DreamBooth, LoRA and other things from config:')
        print(config)
    else:
        # Empty dict: every config.get(...) below falls back to its default.
        config = dict()
    base_model_path = BASE_MODEL
    unet_path = I2V_MODEL
    dreambooth_path = config.get('dreambooth', DREAMBOOTH_PATH)
    vae_path = config.get('vae', None)
    lora_path = config.get('lora', None)
    lora_alpha = config.get('lora_alpha', 0)
    only_load_vae_decoder = config.get('only_load_vae_decoder', False)
    only_load_vae_encoder = config.get('only_load_vae_encoder', False)
    st_motion = config.get('st_motion', None)
    base_cfg = OmegaConf.load(BASE_CFG)
    validation_pipeline = I2VPipeline.build_pipeline(
        base_cfg,
        base_model_path,
        unet_path,
        dreambooth_path,
        lora_path,
        lora_alpha,
        vae_path,
        ip_adapter_path='./models/IP_Adapter/',
        ip_adapter_scale=args.ip_scale,
        only_load_vae_decoder=only_load_vae_decoder,
        only_load_vae_encoder=only_load_vae_encoder)
    print(f'using unet : {unet_path}')
    print(f'using DreamBooth: {dreambooth_path}')
    print(f'using Lora : {lora_path}')
    validation_pipeline.set_st_motion(st_motion)
    print(f'Set Style Transfer Motion: {validation_pipeline.st_motion}.')

    # load image
    image_in, height, width = preprocess_img(args.img)
    # An optional config 'suffix' is prepended to the user prompt. It is read
    # with .get() so this works for both the OmegaConf config and the plain
    # dict.
    suffix = config.get('suffix', None)
    prompt = f'{suffix},{args.prompt}' if suffix else args.prompt

    # Seed right before sampling so the sampled noise depends only on --seed,
    # not on any RNG use during model loading.
    if args.seed is not None:
        seed_everything(args.seed)
    sample = validation_pipeline(
        image=image_in,
        prompt=prompt,
        height=height,
        width=width,
        video_length=16,
        num_inference_steps=25,
        mask_sim_template_idx=args.motion,
        negative_prompt=config.get('n_prompt', N_PROMPT),
        guidance_scale=config.get('guidance_scale', GUIDANCE_SCALE),
        ip_adapter_scale=args.ip_scale,
        strength=args.strength
    ).videos

    save_name = args.save_name
    parent_name = os.path.dirname(save_name)
    if parent_name:
        os.makedirs(parent_name, exist_ok=True)
    imageio.mimsave(save_name, post_process(sample))
    print(" <<< Test Done <<<")