# prismer/prismer: experts/generate_edge.py  (1.91 kB; hosting scan: no virus)
# Last commit b734d92 by shikunl: "Reset again!"
# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE
import torch
import os
# Import ruamel's YAML module as `yaml`: prefer the standalone `ruamel_yaml`
# package name and fall back to the `ruamel.yaml` namespace package when the
# former is not installed.
try:
    import ruamel_yaml as yaml
except ModuleNotFoundError:
    import ruamel.yaml as yaml
from experts.model_bank import load_expert_model
from experts.edge.generate_dataset import Dataset
from experts.edge.images import fuse_edge
import PIL.Image as Image
from accelerate import Accelerator
from tqdm import tqdm
# Load the edge expert model together with the transform its inputs need.
model, transform = load_expert_model(task='edge')
# Accelerate handles device placement; inference runs with fp16 mixed precision.
accelerator = Accelerator(mixed_precision='fp16')

# Read the shared experts config. The legacy `yaml.load(stream, Loader=...)`
# function was removed in ruamel.yaml 0.18, so use the YAML() object API
# (available since ruamel.yaml 0.15). The safe, pure-Python loader is used on
# the assumption that configs/experts.yaml holds only plain mappings/strings
# (no Python-object tags) — TODO confirm if the config ever grows custom tags.
# The `with` block guarantees the file handle is closed after parsing.
with open('configs/experts.yaml', 'r', encoding='utf-8') as config_file:
    config = yaml.YAML(typ='safe', pure=True).load(config_file)
data_path = config['data_path']
# Edge maps are written to an 'edge' sub-directory of the configured save path.
save_path = os.path.join(config['save_path'], 'edge')
# Build the inference dataset from the configured data path, applying the
# expert's own input transform to every sample.
batch_size = 64
dataset = Dataset(data_path, transform)

# Deterministic order (no shuffling) with 4 loader worker processes and
# pinned host memory for faster host-to-device copies.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Hand model and loader to Accelerate, which takes care of device placement
# (and, when several processes are launched, of splitting the data).
model, data_loader = accelerator.prepare(model, data_loader)
# Run edge inference over the whole dataset (gradients disabled) and save one
# grayscale PNG per input image. The output keeps the last two directory
# levels of the source path under `save_path`, so same-named files from
# different sub-folders do not overwrite each other.
with torch.no_grad():
    for test_data, img_path, img_size in tqdm(data_loader):
        test_pred = model(test_data)
        fuses = fuse_edge(test_pred)
        for k, fused in enumerate(fuses):
            img_path_split = img_path[k].split('/')
            im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
            os.makedirs(im_save_path, exist_ok=True)

            # Resize the prediction back to the size reported by the loader.
            # The pair is passed straight to PIL's resize, i.e. treated as
            # (width, height) — presumably what Dataset emits; TODO confirm.
            im_size = img_size[0][k].item(), img_size[1][k].item()
            edge_image = Image.fromarray(fused).convert('L')
            edge_image = edge_image.resize(im_size, resample=Image.Resampling.BILINEAR)

            # Swap only the final extension of the file name for '.png'
            # (e.g. 'a.jpg.jpg' -> 'a.jpg.png'); a blanket str.replace of the
            # extension would also rewrite earlier occurrences in the name.
            stem = os.path.splitext(img_path_split[-1])[0]
            edge_image.save(os.path.join(im_save_path, stem + '.png'))