""" >>> python detect.py --dataroot ./imgs/horse.jpg --style horse2zebra """ import os import sys BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(BASE_DIR) from pathlib import Path from data import OneDataset from models import create_model from options import DetectOptions from util import tensor2im, save_image, show_image, now_time, print_info # 定义参数 opt = DetectOptions().parse() # 硬编码一些测试参数 opt.num_threads = 0 # 测试代码仅支持num_threads = 0 opt.batch_size = 1 # 测试代码仅支持batch_size = 1 opt.serial_batches = True # 禁用数据混洗;如果需要在随机选择的图像上得到结果,请取消对此行的注释。 opt.no_flip = True # 不翻转;如果需要在翻转的图像上得到结果,请取消对此行的注释。 opt.display_id = -1 # 没有visdom显示;测试代码将结果保存到HTML文件中。 # 加载模型 model = create_model(opt) def detect(img=opt.dataroot, style=opt.style): result = None time_info = "-" * 30 + f"\n{now_time()}:start" dataset = OneDataset(img, opt) # 加载数据 model.setup(opt, style) # 设置模型 model.eval() # 切换到评估模式 print_info(time_info) for _, data in enumerate(dataset): model.set_input(data) # 从数据加载器中解包数据 model.test() # 推理 visuals = model.get_current_visuals() # 获取结果图像 result = tensor2im(visuals["fake"]) time_info = f"{now_time()}:done\n" + "-" * 30 print_info(time_info) return result def save_detect_img(results: list, no_save_img=False): """保存或展示检测结果""" # 保存 or 展示 if no_save_img: for _, img_fake in results: show_image(img_fake) else: for img_path, img_fake in results: save_dir = Path.cwd().joinpath("results") # fake图片保存路径 img_path = Path(img_path) img_name = img_path.stem + "_fake" + img_path.suffix Path.mkdir(save_dir, exist_ok=True) save_path = save_dir.joinpath(img_name) save_image(img_fake, save_path) print("results_path: ", save_path) if __name__ == "__main__": results = detect() # 推理图片 save_detect_img(results, opt.no_save_img) # 保存检测结果