File size: 2,368 Bytes
58da73e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
>>> python detect.py --dataroot ./imgs/horse.jpg --style horse2zebra
"""

import os
import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)
from pathlib import Path

from data import OneDataset
from models import create_model
from options import DetectOptions
from util import tensor2im, save_image, show_image, now_time, print_info

# 定义参数
opt = DetectOptions().parse()
# 硬编码一些测试参数
opt.num_threads = 0  # 测试代码仅支持num_threads = 0
opt.batch_size = 1  # 测试代码仅支持batch_size = 1
opt.serial_batches = True  # 禁用数据混洗;如果需要在随机选择的图像上得到结果,请取消对此行的注释。
opt.no_flip = True  # 不翻转;如果需要在翻转的图像上得到结果,请取消对此行的注释。
opt.display_id = -1  # 没有visdom显示;测试代码将结果保存到HTML文件中。

#  加载模型
model = create_model(opt)


def detect(img=opt.dataroot, style=opt.style):
    result = None
    time_info = "-" * 30 + f"\n{now_time()}:start"
    dataset = OneDataset(img, opt)  # 加载数据
    model.setup(opt, style)  # 设置模型
    model.eval()  # 切换到评估模式
    print_info(time_info)
    for _, data in enumerate(dataset):
        model.set_input(data)  # 从数据加载器中解包数据
        model.test()  # 推理
        visuals = model.get_current_visuals()  # 获取结果图像
        result = tensor2im(visuals["fake"])
        time_info = f"{now_time()}:done\n" + "-" * 30
        print_info(time_info)
    return result


def save_detect_img(results: list, no_save_img=False):
    """保存或展示检测结果"""
    # 保存 or 展示
    if no_save_img:
        for _, img_fake in results:
            show_image(img_fake)
    else:
        for img_path, img_fake in results:
            save_dir = Path.cwd().joinpath("results")  # fake图片保存路径
            img_path = Path(img_path)
            img_name = img_path.stem + "_fake" + img_path.suffix
            Path.mkdir(save_dir, exist_ok=True)
            save_path = save_dir.joinpath(img_name)
            save_image(img_fake, save_path)
            print("results_path: ", save_path)


if __name__ == "__main__":
    results = detect()  # 推理图片
    save_detect_img(results, opt.no_save_img)  # 保存检测结果