# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
import argparse
import os
import random

import imageio
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image

from animatediff.pipelines import I2VPipeline

N_PROMPT = 'worst quality,low quality'
GUIDANCE_SCALE = 7
BASE_CFG = './example/config/base.yaml'
I2V_MODEL = './models/PIA/pia.ckpt'
BASE_MODEL = './models/StableDiffusion/sd15'
DREAMBOOTH_PATH = './models/DreamBooth_LoRA/Counterfeit-V3.0_fp32.safetensors'
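

# Convert pipeline output (b, c, t, h, w, values expected in [0, 1]) into
# uint8 HWC frames that imageio can write.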
def post_process(videos: torch.Tensor):
    videos = rearrange(videos[0], "c t h w -> t h w c")
    videos = (videos * 255).clip(0, 255).cpu().numpy().astype(np.uint8)
    return videos
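

# Seed all RNGs (torch, CUDA, numpy, random) for reproducible sampling.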
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed % (2**32))
    random.seed(seed)
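

# Resize the input image so its long edge is at most 512 px, then round both
# sides down to a multiple of 8 (Stable Diffusion's VAE downsamples by 8).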
def preprocess_img(img_path):
    ori_image = Image.open(img_path).convert('RGB')

    width, height = ori_image.size
    long_edge = max(width, height)
    if long_edge > 512:
        scale_factor = 512 / long_edge
    else:
        scale_factor = 1
    width = int(width * scale_factor)
    height = int(height * scale_factor)
    ori_image = ori_image.resize((width, height))

    if (width % 8 != 0) or (height % 8 != 0):
        in_width = (width // 8) * 8
        in_height = (height // 8) * 8
    else:
        in_width = width
        in_height = height

    in_image = ori_image.resize((in_width, in_height))
    in_image_np = np.array(in_image)
    return in_image_np, in_height, in_width
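

# CLI entry point: parse arguments, build the PIA image-to-video pipeline
# and animate a single input image.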
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--img', type=str, help='path to the input image')
    parser.add_argument('--config', type=str, help='optional OmegaConf config overriding default paths and settings')
    parser.add_argument('--prompt', type=str, help='text prompt describing the animation')
    parser.add_argument('--save-name', type=str, help='output file path (format inferred from the suffix, e.g. .gif)')
    parser.add_argument('--motion', type=int, default=2, help='motion template index (mask_sim_template_idx)')
    parser.add_argument('--ip-scale', type=float, default=0.3, help='IP-Adapter scale')
    parser.add_argument('--strength', type=float, default=1, help='denoising strength')
    args = parser.parse_args()

    # prepare paths and pipeline
    if args.config:
        config = OmegaConf.load(args.config)
        print('Loading DreamBooth, LoRA and other settings from config:')
        print(config)
    else:
        config = dict()

    base_model_path = BASE_MODEL
    unet_path = I2V_MODEL
    dreambooth_path = config.get('dreambooth', DREAMBOOTH_PATH)
    vae_path = config.get('vae', None)
    lora_path = config.get('lora', None)
    lora_alpha = config.get('lora_alpha', 0)
    only_load_vae_decoder = config.get('only_load_vae_decoder', False)
    only_load_vae_encoder = config.get('only_load_vae_encoder', False)
    st_motion = config.get('st_motion', None)

    base_cfg = OmegaConf.load(BASE_CFG)
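
    # Assemble the pipeline from the base SD 1.5 weights, the PIA checkpoint
    # (unet_path), optional DreamBooth/LoRA/VAE overrides and the IP-Adapter.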
    validation_pipeline = I2VPipeline.build_pipeline(
        base_cfg,
        base_model_path,
        unet_path,
        dreambooth_path,
        lora_path,
        lora_alpha,
        vae_path,
        ip_adapter_path='./models/IP_Adapter/',
        ip_adapter_scale=args.ip_scale,
        only_load_vae_decoder=only_load_vae_decoder,
        only_load_vae_encoder=only_load_vae_encoder)
    print(f'using unet : {unet_path}')
    print(f'using DreamBooth: {dreambooth_path}')
    print(f'using Lora : {lora_path}')

    validation_pipeline.set_st_motion(st_motion)
    print(f'Set Style Transfer Motion: {validation_pipeline.st_motion}.')

    # load image
    image_in, height, width = preprocess_img(args.img)

    if config.get('suffix', None):
        prompt = config.suffix + ',' + args.prompt
    else:
        prompt = args.prompt
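
    # Run inference: generate a 16-frame clip conditioned on the input image,
    # the prompt and the selected motion template (mask_sim_template_idx).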
    sample = validation_pipeline(
        image=image_in,
        prompt=prompt,
        height=height,
        width=width,
        video_length=16,
        num_inference_steps=25,
        mask_sim_template_idx=args.motion,
        negative_prompt=config.get('n_prompt', N_PROMPT),
        guidance_scale=config.get('guidance_scale', GUIDANCE_SCALE),
        ip_adapter_scale=args.ip_scale,
        strength=args.strength,
    ).videos
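
    # Write the frames to disk; imageio infers the output format (e.g. GIF)
    # from the file suffix of --save-name.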
    save_name = args.save_name
    parent_name = os.path.dirname(save_name)
    if parent_name:
        os.makedirs(parent_name, exist_ok=True)
    imageio.mimsave(save_name, post_process(sample))

    print(" <<< Test Done <<<")