# simswap55/app.py
# Commit 22b8701 by LB5: "Upload 45 files"
from pathlib import Path
from typing import Optional
from tqdm import tqdm
import hydra
from omegaconf import DictConfig
import numpy as np
from src.simswap import SimSwap
from src.DataManager.ImageDataManager import ImageDataManager
from src.DataManager.VideoDataManager import VideoDataManager
from src.DataManager.utils import imread_rgb
class Application:
    """Drive a SimSwap face swap from a Hydra config.

    Reads a required identity image, an optional "specific" identity image and
    exactly one attribute source (an image file/directory or a video file) from
    ``config.data``, builds a ``SimSwap`` model from ``config.pipeline``, and in
    :meth:`run` feeds every item of the attribute source through the model.
    """

    def __init__(self, config: DictConfig):
        """Load the inputs and build the model.

        Args:
            config: Hydra config. Reads ``data.id_image``, ``data.specific_id_image``,
                ``data.att_image``, ``data.att_video``, ``data.output_dir``,
                ``data.clean_work_dir`` (only when a video source is used) and
                ``pipeline``. Path entries set to ``None`` or ``""`` are treated
                as unset.

        Raises:
            FileNotFoundError: if ``data.id_image`` is unset or is not an existing
                file, or if neither attribute source exists.
            ValueError: if both an attribute image and an attribute video exist.
        """
        # NOTE(review): depending on the Hydra version / ``hydra.job.chdir``
        # setting, ``@hydra.main`` may change the working directory before this
        # runs, so relative paths below could resolve against the Hydra run dir.
        # Confirm, or wrap them with ``hydra.utils.to_absolute_path``.
        id_image_path = self._optional_path(config.data.id_image)
        specific_id_image_path = self._optional_path(config.data.specific_id_image)
        att_image_path = self._optional_path(config.data.att_image)
        att_video_path = self._optional_path(config.data.att_video)
        output_dir = Path(config.data.output_dir)

        # Explicit raises instead of ``assert``: asserts vanish under ``python -O``.
        if id_image_path is None:
            raise FileNotFoundError("config.data.id_image is not set!")
        if not id_image_path.is_file():
            raise FileNotFoundError(f"Can't find {id_image_path} file!")
        self.id_image: np.ndarray = imread_rgb(id_image_path)

        # The specific identity image is optional; a missing file means "not used".
        self.specific_id_image: Optional[np.ndarray] = (
            imread_rgb(specific_id_image_path)
            if specific_id_image_path is not None and specific_id_image_path.is_file()
            else None
        )

        # The attribute image source may be a single file or a directory.
        self.att_image: Optional[ImageDataManager] = None
        if att_image_path is not None and (
            att_image_path.is_file() or att_image_path.is_dir()
        ):
            self.att_image = ImageDataManager(
                src_data=att_image_path, output_dir=output_dir
            )

        self.att_video: Optional[VideoDataManager] = None
        if att_video_path is not None and att_video_path.is_file():
            self.att_video = VideoDataManager(
                src_data=att_video_path,
                output_dir=output_dir,
                clean_work_dir=config.data.clean_work_dir,
            )

        # Compare against None explicitly: the managers define __len__, so
        # truthiness would treat an empty-but-configured source as absent.
        if self.att_video is not None and self.att_image is not None:
            raise ValueError("Only one attribute source can be used!")
        self.data_manager = (
            self.att_video if self.att_video is not None else self.att_image
        )
        # Fail before loading the model rather than with len(None) in run().
        if self.data_manager is None:
            raise FileNotFoundError(
                f"No attribute source found: neither att_image={att_image_path} "
                f"(file or directory) nor att_video={att_video_path} (file) exists!"
            )

        self.model = SimSwap(
            config=config.pipeline,
            id_image=self.id_image,
            specific_image=self.specific_id_image,
        )

    @staticmethod
    def _optional_path(value: Optional[str]) -> Optional[Path]:
        """Convert a config value to a ``Path``, mapping ``None`` and ``""`` to ``None``.

        ``Path("")`` normalizes to ``Path(".")`` (an existing directory), and
        ``Path(None)`` raises ``TypeError``. Mapping both to ``None`` keeps an
        unset entry from being mistaken for the current directory.
        """
        if value is None or str(value) == "":
            return None
        return Path(value)

    def run(self):
        """Swap faces on every item of the attribute source and save the results.

        Calls ``get()`` on the data manager ``len(self.data_manager)`` times,
        passes each item through the model, and hands each output to ``save()``.
        """
        for _ in tqdm(range(len(self.data_manager))):
            att_img = self.data_manager.get()
            output = self.model(att_img)
            self.data_manager.save(output)
@hydra.main(config_path="configs/", config_name="run_image.yaml")
def main(config: DictConfig):
    """Entry point: build an ``Application`` from the composed config and run it.

    The decorator composes ``config`` from ``configs/run_image.yaml``.
    """
    Application(config).run()


if __name__ == "__main__":
    # Called without arguments: the ``hydra.main`` wrapper supplies ``config``.
    main()