Spaces:
Configuration error
Configuration error
from pathlib import Path | |
from typing import Optional | |
from tqdm import tqdm | |
import hydra | |
from omegaconf import DictConfig | |
import numpy as np | |
from src.simswap import SimSwap | |
from src.DataManager.ImageDataManager import ImageDataManager | |
from src.DataManager.VideoDataManager import VideoDataManager | |
from src.DataManager.utils import imread_rgb | |
class Application:
    """Face-swap driver: pairs identity image(s) with one attribute source and runs SimSwap.

    Exactly one attribute source must be configured: ``config.data.att_image`` (an
    existing file or directory) or ``config.data.att_video`` (an existing file).
    ``config.data.output_dir`` is passed to the selected data manager, which receives
    each model output via ``save()``.
    """

    @staticmethod
    def _to_optional_path(value) -> Optional[Path]:
        """Convert a config value to a ``Path``, or ``None`` when the value is unset.

        Filtering is needed because ``Path(None)`` raises TypeError and ``Path("")``
        becomes ``Path(".")`` -- an existing directory that would otherwise be picked
        up as an attribute-image folder.
        """
        if value is None:
            return None
        text = str(value).strip()
        return Path(text) if text else None

    def __init__(self, config: DictConfig):
        """Validate the configured inputs, build the data manager and the SimSwap model.

        Args:
            config: Hydra/OmegaConf config. Reads ``config.data.id_image``,
                ``specific_id_image``, ``att_image``, ``att_video``, ``output_dir`` and
                (for video input) ``clean_work_dir``; ``config.pipeline`` is forwarded
                to ``SimSwap``.

        Raises:
            FileNotFoundError: ``config.data.id_image`` is unset or not an existing file.
            ValueError: both, or neither, of ``att_image`` / ``att_video`` point to an
                existing source.
        """
        data = config.data

        id_image_path = self._to_optional_path(data.id_image)
        specific_id_image_path = self._to_optional_path(data.specific_id_image)
        att_image_path = self._to_optional_path(data.att_image)
        att_video_path = self._to_optional_path(data.att_video)
        output_dir = Path(data.output_dir)

        # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
        if id_image_path is None or not id_image_path.is_file():
            raise FileNotFoundError(f"Can't find {id_image_path} file!")

        # Decide which attribute source is usable *before* constructing any data
        # manager, so a misconfiguration fails without building (and then discarding)
        # a manager.
        use_image = att_image_path is not None and (
            att_image_path.is_file() or att_image_path.is_dir()
        )
        use_video = att_video_path is not None and att_video_path.is_file()
        if use_image and use_video:
            raise ValueError("Only one attribute source can be used!")
        if not (use_image or use_video):
            raise ValueError(
                "No attribute source found: set data.att_image or data.att_video "
                "to an existing path."
            )

        self.id_image: Optional[np.ndarray] = imread_rgb(id_image_path)
        # The specific-identity image is optional; it is only read when the file exists.
        self.specific_id_image: Optional[np.ndarray] = (
            imread_rgb(specific_id_image_path)
            if specific_id_image_path is not None and specific_id_image_path.is_file()
            else None
        )

        self.att_image: Optional[ImageDataManager] = None
        self.att_video: Optional[VideoDataManager] = None
        if use_image:
            self.att_image = ImageDataManager(src_data=att_image_path, output_dir=output_dir)
            self.data_manager = self.att_image
        else:
            self.att_video = VideoDataManager(
                src_data=att_video_path,
                output_dir=output_dir,
                clean_work_dir=data.clean_work_dir,
            )
            self.data_manager = self.att_video
        # The selection above deliberately avoids truthiness tests on the managers:
        # they support len(), so an empty source would be falsy and mistaken for "missing".

        self.model = SimSwap(
            config=config.pipeline,
            id_image=self.id_image,
            specific_image=self.specific_id_image,
        )

    def run(self):
        """Run the model over every item of the attribute source and save each result.

        Each iteration pulls one item with ``data_manager.get()``, passes it to the
        model, and hands the output to ``data_manager.save()``. The number of
        iterations (and the progress-bar length) is ``len(data_manager)``.
        """
        for _ in tqdm(range(len(self.data_manager))):
            att_img = self.data_manager.get()
            output = self.model(att_img)
            self.data_manager.save(output)
# NOTE(review): config_path / config_name are assumed to match this project's Hydra
# config layout (a `configs/` directory next to this script with a `run_image.yaml`
# default) -- confirm. Either can be overridden on the command line with
# `--config-path` / `--config-name`.
@hydra.main(config_path="configs/", config_name="run_image.yaml")
def main(config: DictConfig):
    """Build an ``Application`` from the Hydra-composed config and run it.

    Without the ``@hydra.main`` decorator the bare ``main()`` call below would raise
    TypeError (missing ``config``). The decorator composes the config from the
    command line, and still accepts an explicit config object when called directly.
    """
    app = Application(config)
    app.run()


if __name__ == "__main__":
    main()