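# NOTE(review): the status banner below appears to come from the hosting page
# this file was captured from, not from the program itself:
#   Spaces: Runtime error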
# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE
import os

import torch
import PIL.Image as Image
from accelerate import Accelerator
from tqdm import tqdm

try:
    import ruamel_yaml as yaml
except ModuleNotFoundError:
    import ruamel.yaml as yaml

from experts.model_bank import load_expert_model
from experts.edge.generate_dataset import Dataset
from experts.edge.images import fuse_edge
# --- Model, runtime and data pipeline setup ---------------------------------

# Load the edge expert together with the transform returned alongside it;
# the transform is handed to the dataset below.
model, transform = load_expert_model(task='edge')

# Accelerate is asked for fp16 mixed precision.
accelerator = Accelerator(mixed_precision='fp16')

# Read the experts configuration. The context manager closes the file handle
# deterministically instead of leaving it open until garbage collection.
# NOTE(review): yaml.Loader is the full (non-safe) loader; that is only
# acceptable because this is a local, trusted config file.
with open('configs/experts.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)

data_path = config['data_path']
# Edge outputs go into an 'edge' sub-directory of the configured save path.
save_path = os.path.join(config['save_path'], 'edge')

batch_size = 64
dataset = Dataset(data_path, transform)
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Let Accelerate wrap the model and the loader before inference.
model, data_loader = accelerator.prepare(model, data_loader)
# --- Inference: predict edges and write one grayscale PNG per input image ----
with torch.no_grad():
    for test_data, img_path, img_size in tqdm(data_loader):
        test_pred = model(test_data)
        fuses = fuse_edge(test_pred)

        for k, edge in enumerate(fuses):
            img_path_split = img_path[k].split('/')

            # Reuse the two directory levels directly above the image file
            # as the sub-directory layout under save_path.
            im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
            os.makedirs(im_save_path, exist_ok=True)

            # splitext strips only the final extension. The previous
            # str.replace approach rewrote every occurrence of '.ext' in the
            # name and left extension-less files without a '.png' suffix.
            stem, _ = os.path.splitext(img_path_split[-1])

            # img_size is indexed as [component][sample]; the pair is passed
            # to PIL's resize in that order.
            # NOTE(review): presumably (width, height) as PIL expects - confirm
            # against the Dataset implementation.
            im_size = img_size[0][k].item(), img_size[1][k].item()

            edge = Image.fromarray(edge).convert('L')
            edge = edge.resize(im_size, resample=Image.Resampling.BILINEAR)
            edge.save(os.path.join(im_save_path, f'{stem}.png'))