Spaces:
Running
Running
File size: 1,928 Bytes
a25563f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import os
import json
import random
from config import DATA_ROOT
# Default dataset root: the 'SDD_anomaly_detection' directory under the
# project-wide DATA_ROOT imported from config. Used as the default `root`
# of SDDSolver and by the script entry point.
SDD_ROOT = os.path.join(DATA_ROOT, 'SDD_anomaly_detection')
class SDDSolver(object):
    """Build the ``meta.json`` index for the SDD anomaly-detection dataset.

    The directory layout read by :meth:`generate_meta_info` is::

        {root}/{cls_name}/{phase}/{specie}/<image files>
        {root}/{cls_name}/ground_truth/{specie}/<mask files>

    where ``phase`` is ``'train'`` or ``'test'``.  The ``'good'`` specie is
    treated as normal and has no ground-truth masks; every other specie is
    treated as anomalous and must have one mask per image.
    """

    CLSNAMES = [
        'SDD',
    ]

    def __init__(self, root=SDD_ROOT, train_ratio=0.5):
        """
        Args:
            root: Dataset root directory. ``meta.json`` is written directly
                inside it.
            train_ratio: Stored on the instance but not read anywhere in this
                class. NOTE(review): presumably kept for interface parity
                with other dataset solvers — confirm before removing.
        """
        self.root = root
        self.meta_path = f'{root}/meta.json'
        self.train_ratio = train_ratio

    def run(self):
        """Generate the meta-info index and write it to ``self.meta_path``."""
        self.generate_meta_info()

    @staticmethod
    def _list_entries(path):
        """Return the sorted names in directory *path*, excluding dot-files."""
        # Dot-files such as `.DS_Store` sort before the real files: left in,
        # they would be treated as a specie directory or shift the positional
        # image/mask pairing in generate_meta_info by one.
        return sorted(name for name in os.listdir(path) if not name.startswith('.'))

    def generate_meta_info(self):
        """Scan the dataset and write the train/test index as JSON.

        The written object has the shape ``{phase: {cls_name: [sample, ...]}}``
        where each sample is a dict with keys ``img_path``, ``mask_path``
        (``''`` for ``'good'`` samples), ``cls_name``, ``specie_name`` and
        ``anomaly`` (1 for anomalous species, 0 for ``'good'``).  All paths are
        relative to ``self.root``.

        Raises:
            ValueError: if an anomalous specie has a different number of
                images than ground-truth masks.  Masks are matched to images
                by sorted position, so a count mismatch means the pairing
                cannot be trusted.
        """
        info = dict(train={}, test={})
        for cls_name in self.CLSNAMES:
            cls_dir = f'{self.root}/{cls_name}'
            for phase in ['train', 'test']:
                cls_info = []
                # Species are sorted (via _list_entries) so the output file is
                # deterministic; os.listdir order is filesystem-dependent.
                for specie in self._list_entries(f'{cls_dir}/{phase}'):
                    is_abnormal = specie != 'good'
                    img_dir = f'{cls_dir}/{phase}/{specie}'
                    img_names = self._list_entries(img_dir)
                    if is_abnormal:
                        mask_dir = f'{cls_dir}/ground_truth/{specie}'
                        mask_names = self._list_entries(mask_dir)
                        if len(mask_names) != len(img_names):
                            raise ValueError(
                                f'{img_dir} has {len(img_names)} images but '
                                f'{mask_dir} has {len(mask_names)} masks; '
                                f'cannot pair them by sorted position'
                            )
                        mask_paths = [
                            f'{cls_name}/ground_truth/{specie}/{mask_name}'
                            for mask_name in mask_names
                        ]
                    else:
                        mask_paths = [''] * len(img_names)
                    for img_name, mask_path in zip(img_names, mask_paths):
                        cls_info.append(dict(
                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
                            mask_path=mask_path,
                            cls_name=cls_name,
                            specie_name=specie,
                            anomaly=1 if is_abnormal else 0,
                        ))
                info[phase][cls_name] = cls_info
        with open(self.meta_path, 'w') as f:
            f.write(json.dumps(info, indent=4) + "\n")
if __name__ == '__main__':
    # Script entry point: rebuild meta.json under the default SDD root.
    SDDSolver(root=SDD_ROOT).run()
|