Spaces:
Running
Running
File size: 2,974 Bytes
a25563f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import os
import numpy as np
from sklearn.model_selection import train_test_split
import cv2
import argparse
from config import DATA_ROOT
# Convert the DAGM 2007 dataset into an MVTec-AD-style anomaly-detection layout:
#   <target_root>/<class>/train/good          normal images for training
#   <target_root>/<class>/test/good           held-out normal images
#   <target_root>/<class>/test/defect         images that have a label mask
#   <target_root>/<class>/ground_truth/defect the label masks themselves
dataset_root = os.path.join(DATA_ROOT, 'DAGM2007')
# Relative to the current working directory, not to this file.
target_root = '../datasets/DAGM_anomaly_detection'


def _list_subdirs(directory):
    """Return the sorted names of the sub-directories of *directory*.

    Plain files (e.g. a README or archive next to the class folders) are
    skipped instead of crashing the later ``os.listdir`` call on them.
    """
    return sorted(
        entry for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )


def _list_pngs(directory):
    """Return the sorted names of the entries in *directory* ending in ``PNG``.

    The suffix match is case-sensitive. Sorting matters because
    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order,
    and the normal train/test split below is taken positionally.
    """
    return sorted(f for f in os.listdir(directory) if f.endswith('PNG'))


def _split_by_mask(images, masks):
    """Partition *images* into ``(normal, abnormal)`` lists.

    An image counts as abnormal when its 4-character id occurs as a substring
    of some mask file name. The id is the four characters before the
    4-character extension, e.g. ``'0002'`` for ``'0002.PNG'``.
    """
    normal, abnormal = [], []
    for image in images:
        image_id = image[-8:-4]
        has_mask = next((m for m in masks if image_id in m), None) is not None
        (abnormal if has_mask else normal).append(image)
    return normal, abnormal


def _copy_images(src_dir, names, dst_dir):
    """Re-encode each ``src_dir/<name>`` into ``dst_dir/<name>`` with OpenCV.

    *dst_dir* is created if missing, even when *names* is empty. With its
    default flag ``cv2.imread`` loads images as 3-channel BGR, so grayscale
    sources are written back as 3-channel files.

    Raises:
        FileNotFoundError: if OpenCV cannot read a source image.
            ``cv2.imread`` signals failure by returning ``None``, which
            would otherwise surface as an opaque error inside ``cv2.imwrite``.
    """
    os.makedirs(dst_dir, exist_ok=True)
    for name in names:
        src = os.path.join(src_dir, name)
        image_data = cv2.imread(src)
        if image_data is None:
            raise FileNotFoundError('cv2 could not read image: {}'.format(src))
        cv2.imwrite(os.path.join(dst_dir, name), image_data)


class_names = _list_subdirs(dataset_root)
for class_name in class_names:
    class_dir = os.path.join(dataset_root, class_name)
    class_out = os.path.join(target_root, class_name)
    for state in _list_subdirs(class_dir):
        state_dir = os.path.join(class_dir, state)
        label_dir = os.path.join(state_dir, 'Label')
        images = _list_pngs(state_dir)
        mask = _list_pngs(label_dir)
        normal_image_path, abnormal_image_path = _split_by_mask(images, mask)
        # Hold out as many normal images for testing as there are defective
        # ones; the remaining normal images go to the training split.
        length = len(abnormal_image_path)
        normal_image_path_test = normal_image_path[:length]
        normal_image_path_train = normal_image_path[length:]
        # NOTE(review): every state (e.g. Train and Test) writes into the same
        # per-class output folders, so a file name shared by two states is
        # overwritten by the later one. Confirm DAGM names are unique across
        # states.
        _copy_images(state_dir, normal_image_path_train,
                     os.path.join(class_out, 'train', 'good'))
        _copy_images(state_dir, normal_image_path_test,
                     os.path.join(class_out, 'test', 'good'))
        _copy_images(state_dir, abnormal_image_path,
                     os.path.join(class_out, 'test', 'defect'))
        _copy_images(label_dir, mask,
                     os.path.join(class_out, 'ground_truth', 'defect'))
print("Done")