# StreamingT2V: t2v_enhanced/inference_utils.py
# import argparse
import sys
from pathlib import Path
from pytorch_lightning.cli import LightningCLI
from PIL import Image
# For streaming
import yaml
from copy import deepcopy
from typing import List, Optional
from jsonargparse.typing import restricted_string_type
# --------------------------------------
# ----------- For Streaming ------------
# --------------------------------------
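# CustomCLI extends LightningCLI with the extra command line options used for
# streaming inference (result folder, prompts, checkpoint path, etc.).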
class CustomCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.add_argument("--result_fol", type=Path,
                            help="Set the path to the result folder", default="results")
        parser.add_argument("--exp_name", type=str, help="Experiment name")
        parser.add_argument("--run_name", type=str,
                            help="Current run name")
        parser.add_argument("--prompts", type=Optional[List[str]])
        parser.add_argument("--scale_lr", type=bool,
                            help="Scale lr", default=False)
        CodeType = restricted_string_type(
            'CodeType', '(medium)|(high)|(highest)')
        parser.add_argument("--matmul_precision", type=CodeType)
        parser.add_argument("--ckpt", type=Path)
        parser.add_argument("--n_predictions", type=int)
        return parser
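
# Recursively remove every entry whose key equals x from a (possibly nested) dict.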
def remove_value(dictionary, x):
    for key, value in list(dictionary.items()):
        if key == x:
            del dictionary[key]
        elif isinstance(value, dict):
            remove_value(value, x)
    return dictionary
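
# Bring a legacy YAML config dict into the current format: force single-device,
# single-node inference and wrap a bare inference_params mapping into a
# class_path / init_args block.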
def legacy_transformation(cfg: dict):
    cfg = deepcopy(cfg)
    cfg["trainer"]["devices"] = "1"
    cfg["trainer"]["num_nodes"] = 1
    if "class_path" not in cfg["model"]["inference_params"]:
        cfg["model"]["inference_params"] = {
            "class_path": "t2v_enhanced.model.pl_module_params.InferenceParams",
            "init_args": cfg["model"]["inference_params"],
        }
    return cfg
# ---------------------------------------------
# ----------- For enhancement -----------
# ---------------------------------------------
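# Pad a PIL image with a constant-color border of the given width on each side.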
def add_margin(pil_img, top, right, bottom, left, color):
    width, height = pil_img.size
    new_width = width + right + left
    new_height = height + top + bottom
    result = Image.new(pil_img.mode, (new_width, new_height), color)
    result.paste(pil_img, (left, top))
    return result
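
# Resize an image so that it fits inside the target (W, H) size while keeping
# its aspect ratio.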
def resize_to_fit(image, size):
    W, H = size
    w, h = image.size
    if H / h > W / w:
        H_ = int(h * W / w)
        W_ = W
    else:
        W_ = int(w * H / h)
        H_ = H
    return image.resize((W_, H_))
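
# Pad an image with black borders so that it is centered inside a (W, H) canvas
# (up to one pixel of rounding from the integer division).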
def pad_to_fit(image, size):
    W, H = size
    w, h = image.size
    pad_h = (H - h) // 2
    pad_w = (W - w) // 2
    return add_margin(image, pad_h, pad_w, pad_h, pad_w, (0, 0, 0))
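
# Resize an image to a fixed height of 576 px, keeping its aspect ratio.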
def resize_and_keep(pil_img):
    myheight = 576
    hpercent = myheight / float(pil_img.size[1])
    wsize = int(float(pil_img.size[0]) * hpercent)
    pil_img = pil_img.resize((wsize, myheight))
    return pil_img
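
# Crop the central 576 x 576 region of an image.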
def center_crop(pil_img):
    width, height = pil_img.size
    new_width = 576
    new_height = 576
    left = (width - new_width) / 2
    top = (height - new_height) / 2
    right = (width + new_width) / 2
    bottom = (height + new_height) / 2
    # Crop the center of the image
    pil_img = pil_img.crop((left, top, right, bottom))
    return pil_img