|
""" |
|
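Run the model configured by ``--style`` on one input image, then save the
generated ("fake") image under ./results/ or display it.

Usage::

    python detect.py --dataroot ./imgs/horse.jpg --style horse2zebra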
|
""" |
|
|
|
import os |
|
import sys |
|
|
|
# Make the directory one level above this script's directory importable, so
# the project-local packages imported below (data, models, options, util)
# can be found when the script is run directly.
_script_dir = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(_script_dir)
sys.path.append(BASE_DIR)
|
from pathlib import Path |
|
|
|
from data import OneDataset |
|
from models import create_model |
|
from options import DetectOptions |
|
from util import tensor2im, save_image, show_image, now_time, print_info |
|
|
|
|
|
def _apply_detect_overrides(options):
    """Force the option values this script relies on, regardless of the CLI.

    Mutates *options* in place and returns the same object.
    """
    # NOTE(review): these look like the usual test-time settings of
    # CycleGAN/pix2pix-style option sets (main-process data loading, batch of
    # one, no shuffling, no flip augmentation, no visdom display) -- confirm
    # against the options package.
    options.num_threads = 0
    options.batch_size = 1
    options.serial_batches = True
    options.no_flip = True
    options.display_id = -1
    return options


# Options are parsed when this module is loaded (presumably from sys.argv),
# then overridden for detection.
opt = _apply_detect_overrides(DetectOptions().parse())

# Created once; every detect() call reuses this global model.
model = create_model(opt)
|
|
|
|
|
def detect(img=opt.dataroot, style=opt.style):
    """Run the model for *style* on *img* and return the generated image.

    The defaults are read from the module-level ``opt`` once, when this
    function is defined.

    Args:
        img: input handed to ``OneDataset``; defaults to ``opt.dataroot``.
        style: style name handed to ``model.setup``; defaults to ``opt.style``.

    Returns:
        ``tensor2im`` of the model's ``"fake"`` visual for the last item the
        dataset yields, or ``None`` if the dataset yields nothing.
    """
    # Print the start marker when it is stamped, i.e. before any setup work.
    print_info("-" * 30 + f"\n{now_time()}:start")

    dataset = OneDataset(img, opt)
    # setup() receives the style on every call, so the shared global model is
    # (re)configured for *style* each time detect() runs.
    model.setup(opt, style)
    model.eval()

    result = None
    for data in dataset:
        model.set_input(data)
        model.test()
        # Only the most recent output is kept; OneDataset presumably yields a
        # single item -- TODO confirm.
        result = tensor2im(model.get_current_visuals()["fake"])

    print_info(f"{now_time()}:done\n" + "-" * 30)
    return result
|
|
|
|
|
def save_detect_img(results: list, no_save_img=False, save_dir=None):
    """Save or display generated images.

    Args:
        results: iterable of ``(img_path, img_fake)`` pairs, where *img_path*
            is the source image path and *img_fake* the generated image.
        no_save_img: if true, show each image with ``show_image`` and save
            nothing.
        save_dir: output directory; defaults to ``<cwd>/results``. It is
            created (with parents) when the first image is saved.

    Each saved file is named ``<stem>_fake<suffix>`` after its source image,
    and its path is printed.
    """
    if no_save_img:
        for _, img_fake in results:
            show_image(img_fake)
        return

    # Loop-invariant: resolve the output directory once.
    out_dir = Path.cwd() / "results" if save_dir is None else Path(save_dir)
    for img_path, img_fake in results:
        src = Path(img_path)
        # Created lazily so that an empty *results* leaves the filesystem untouched.
        out_dir.mkdir(parents=True, exist_ok=True)
        save_path = out_dir / f"{src.stem}_fake{src.suffix}"
        save_image(img_fake, save_path)
        print("results_path: ", save_path)
|
|
|
|
|
if __name__ == "__main__":
    # detect() returns one generated image (or None), whereas
    # save_detect_img() iterates over (source_path, image) pairs; passing the
    # bare image made it unpack array rows and crash. Pair the output with
    # the input path it was generated from.
    fake_img = detect()
    if fake_img is None:
        sys.exit(f"no image was generated from {opt.dataroot}")
    results = [(opt.dataroot, fake_img)]
    save_detect_img(results, opt.no_save_img)
|
|