# import argparse
import sys
from pathlib import Path
from pytorch_lightning.cli import LightningCLI
from PIL import Image

# For streaming
import yaml
from copy import deepcopy
from typing import List, Optional
from jsonargparse.typing import restricted_string_type


# --------------------------------------
# ----------- For Streaming ------------
# --------------------------------------
class CustomCLI(LightningCLI):
    """LightningCLI extended with experiment/inference command-line options."""

    def add_arguments_to_parser(self, parser):
        """Register the extra command-line options on ``parser``.

        Args:
            parser: The argument parser handed to this hook by LightningCLI.

        Returns:
            The same ``parser`` object with the options added.
        """
        parser.add_argument("--result_fol", type=Path,
                            help="Set the path to the result folder",
                            default="results")
        parser.add_argument("--exp_name", type=str, help="Experiment name")
        parser.add_argument("--run_name", type=str, help="Current run name")
        parser.add_argument("--prompts", type=Optional[List[str]])
        parser.add_argument("--scale_lr", type=bool, help="Scale lr",
                            default=False)
        # Anchored so the whole value must be exactly one of the three
        # options. An unanchored alternation would let e.g. "higher" through
        # on a prefix match unless the validator full-matches.
        CodeType = restricted_string_type(
            'CodeType', '^(medium|high|highest)$')
        parser.add_argument("--matmul_precision", type=CodeType)
        parser.add_argument("--ckpt", type=Path)
        parser.add_argument("--n_predictions", type=int)
        return parser


def remove_value(dictionary, x):
    """Recursively delete every entry whose key equals ``x``.

    The dictionary is modified in place. Values that are themselves dicts are
    searched as well; dicts nested inside lists or other containers are not.

    Args:
        dictionary: Mapping to clean.
        x: Key to remove at every nesting level.

    Returns:
        The same ``dictionary`` object, for convenience.
    """
    # Iterate over a snapshot because entries are deleted during the loop.
    for key, value in list(dictionary.items()):
        if key == x:
            del dictionary[key]
        elif isinstance(value, dict):
            remove_value(value, x)
    return dictionary


def legacy_transformation(cfg: dict) -> dict:
    """Return a rewritten copy of a (nested-dict) config; ``cfg`` is untouched.

    Changes applied to the copy:
      * ``trainer.devices`` is set to ``"1"`` and ``trainer.num_nodes`` to 1.
      * If ``model.inference_params`` has no ``class_path`` key, it is wrapped
        as ``{"class_path": ...InferenceParams, "init_args": <old value>}``.

    Raises:
        KeyError: if ``cfg`` lacks ``trainer``, ``model`` or
            ``model.inference_params``.
    """
    cfg = deepcopy(cfg)
    cfg["trainer"]["devices"] = "1"
    cfg["trainer"]["num_nodes"] = 1

    inference_params = cfg["model"]["inference_params"]
    if "class_path" not in inference_params:
        cfg["model"]["inference_params"] = {
            "class_path": "t2v_enhanced.model.pl_module_params.InferenceParams",
            "init_args": inference_params,
        }
    return cfg


# ---------------------------------------------
# ----------- For enhancement -----------
# ---------------------------------------------
def add_margin(pil_img, top, right, bottom, left, color):
    """Return a new image with ``pil_img`` surrounded by solid-colour borders.

    Args:
        pil_img: Source PIL image; the result uses the same mode.
        top, right, bottom, left: Border widths in pixels.
        color: Border fill colour, in a form ``Image.new`` accepts for
            ``pil_img.mode``.

    Returns:
        A new image of size ``(width + left + right, height + top + bottom)``
        with ``pil_img`` pasted at offset ``(left, top)``.
    """
    width, height = pil_img.size
    new_width = width + right + left
    new_height = height + top + bottom
    result = Image.new(pil_img.mode, (new_width, new_height), color)
    result.paste(pil_img, (left, top))
    return result


def resize_to_fit(image, size):
    """Scale ``image`` to the largest size fitting in ``size`` = (W, H).

    The aspect ratio is preserved. The limiting side is set exactly to the
    target, and the other side is truncated to an int.
    """
    W, H = size
    w, h = image.size
    if H / h > W / w:
        # Width is the limiting dimension.
        H_ = int(h * W / w)
        W_ = W
    else:
        # Height is the limiting dimension (or the aspect ratios match).
        W_ = int(w * H / h)
        H_ = H
    return image.resize((W_, H_))


def pad_to_fit(image, size):
    """Centre ``image`` on a black canvas of exactly ``size`` = (W, H).

    Assumes ``image`` is no larger than ``size`` (e.g. the output of
    ``resize_to_fit``). When the size difference is odd, the extra pixel of
    padding goes to the right/bottom so that the result is exactly ``size``.
    """
    W, H = size
    w, h = image.size
    pad_h = H - h
    pad_w = W - w
    top = pad_h // 2
    left = pad_w // 2
    return add_margin(image, top, pad_w - left, pad_h - top, left, (0, 0, 0))


def resize_and_keep(pil_img, height=576):
    """Resize ``pil_img`` to ``height`` pixels tall, preserving aspect ratio.

    The width is scaled by the same factor and truncated to an int.
    """
    scale = height / float(pil_img.size[1])
    new_width = int(float(pil_img.size[0]) * scale)
    return pil_img.resize((new_width, height))


def center_crop(pil_img, crop_width=576, crop_height=576):
    """Crop a ``crop_width`` x ``crop_height`` region centred in ``pil_img``.

    Box coordinates may be fractional when the size difference is odd; they
    are passed unchanged to ``Image.crop``.
    """
    width, height = pil_img.size

    left = (width - crop_width) / 2
    top = (height - crop_height) / 2
    right = (width + crop_width) / 2
    bottom = (height + crop_height) / 2

    # Crop the center of the image
    return pil_img.crop((left, top, right, bottom))