CycleGAN / detect.py
Yanguan's picture
0
58da73e
raw
history blame contribute delete
No virus
2.37 kB
"""
>>> python detect.py --dataroot ./imgs/horse.jpg --style horse2zebra
"""
import os
import sys
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)
from pathlib import Path
from data import OneDataset
from models import create_model
from options import DetectOptions
from util import tensor2im, save_image, show_image, now_time, print_info
# 定义参数
opt = DetectOptions().parse()
# 硬编码一些测试参数
opt.num_threads = 0 # 测试代码仅支持num_threads = 0
opt.batch_size = 1 # 测试代码仅支持batch_size = 1
opt.serial_batches = True # 禁用数据混洗;如果需要在随机选择的图像上得到结果,请取消对此行的注释。
opt.no_flip = True # 不翻转;如果需要在翻转的图像上得到结果,请取消对此行的注释。
opt.display_id = -1 # 没有visdom显示;测试代码将结果保存到HTML文件中。
# 加载模型
model = create_model(opt)
def detect(img=opt.dataroot, style=opt.style):
result = None
time_info = "-" * 30 + f"\n{now_time()}:start"
dataset = OneDataset(img, opt) # 加载数据
model.setup(opt, style) # 设置模型
model.eval() # 切换到评估模式
print_info(time_info)
for _, data in enumerate(dataset):
model.set_input(data) # 从数据加载器中解包数据
model.test() # 推理
visuals = model.get_current_visuals() # 获取结果图像
result = tensor2im(visuals["fake"])
time_info = f"{now_time()}:done\n" + "-" * 30
print_info(time_info)
return result
def save_detect_img(results: list, no_save_img=False):
"""保存或展示检测结果"""
# 保存 or 展示
if no_save_img:
for _, img_fake in results:
show_image(img_fake)
else:
for img_path, img_fake in results:
save_dir = Path.cwd().joinpath("results") # fake图片保存路径
img_path = Path(img_path)
img_name = img_path.stem + "_fake" + img_path.suffix
Path.mkdir(save_dir, exist_ok=True)
save_path = save_dir.joinpath(img_name)
save_image(img_fake, save_path)
print("results_path: ", save_path)
if __name__ == "__main__":
results = detect() # 推理图片
save_detect_img(results, opt.no_save_img) # 保存检测结果