# Source: BuildingExtraction/Tools/GetTrainValTestCSV.py
# Author: KyanChen -- commit ab01e4a ("add model")
import os
import glob
import random
import pandas as pd
import cv2
import tqdm
import numpy as np
class GetTrainTestCSV:
    """Build CSV index files listing dataset image paths (and label paths).

    Every CSV is written to ``self.save_path_csv`` (the ``generate_dep_info``
    directory, created relative to the current working directory in
    ``__init__``) under the file name ``self.csv_name``.
    """

    def __init__(self, dataset_path_list, csv_name, img_format_list, negative_keep_rate=0.1):
        """
        Args:
            dataset_path_list: Paths to scan for images. ``get_csv`` expects
                directories; ``get_csv_file`` accepts files or directories
                (directories are expanded recursively).
            csv_name: File name of the CSV written into ``save_path_csv``.
            img_format_list: Image extensions without the leading dot
                (e.g. ``['png']``). ``get_csv`` pairs them with
                ``dataset_path_list`` by index when both lists have the same
                length, otherwise it uses the first entry for every directory.
            negative_keep_rate: Stored on the instance but not used by any
                method of this class.
        """
        self.data_path_list = dataset_path_list
        self.img_format_list = img_format_list
        self.negative_keep_rate = negative_keep_rate
        self.save_path_csv = r'generate_dep_info'
        os.makedirs(self.save_path_csv, exist_ok=True)
        self.csv_name = csv_name

    def _write_csv(self, data_information):
        """Write ``data_information`` (column name -> list) to ``save_path_csv/csv_name``."""
        data_annotation = pd.DataFrame(data_information)
        writer_name = os.path.join(self.save_path_csv, self.csv_name)
        data_annotation.to_csv(writer_name, index_label=False)
        print(os.path.basename(writer_name) + ' file saves successfully!')

    def get_csv(self, pattern):
        """Scan ``data_path_list`` for images and write an ``img``/``label`` CSV.

        For each directory, files matching ``*<img_format>`` directly inside it
        are indexed; if there are none, each immediate subdirectory is scanned
        instead. The label path is derived from the image path by replacing
        the extension with ``png`` and ``'imgs'`` with ``'labels'``. Labels that
        do not exist on disk -- and every label when ``pattern == 'test'`` --
        are recorded as the string ``'None'``.

        Args:
            pattern: Split name; only ``'test'`` changes behaviour.

        Raises:
            AssertionError: If a listed directory does not exist, or a scanned
                folder contains no matching images.
        """
        def get_data_infos(img_path, img_format):
            # Index every '*<img_format>' file directly inside img_path.
            data_info = {'img': [], 'label': []}
            img_file_list = glob.glob(img_path + '/*%s' % img_format)
            assert len(img_file_list), 'No data in DATASET_PATH!'
            for img_file in tqdm.tqdm(img_file_list):
                if pattern == 'test':
                    label_file = 'None'
                else:
                    # glob guarantees the name ends with img_format, so swap only
                    # that suffix; str.replace on the whole path would also
                    # rewrite directory names that contain the format string.
                    stem = img_file[:len(img_file) - len(img_format)]
                    # NOTE(review): 'imgs' -> 'labels' still applies to every
                    # occurrence in the path -- confirm this matches the layout.
                    label_file = (stem + 'png').replace('imgs', 'labels')
                    if not os.path.exists(label_file):
                        label_file = 'None'
                # NOTE: subsampling of negative samples via
                # self.negative_keep_rate is not implemented here.
                data_info['img'].append(img_file)
                data_info['label'].append(label_file)
            return data_info

        data_information = {'img': [], 'label': []}
        pair_formats = len(self.data_path_list) == len(self.img_format_list)
        for idx, data_dir in enumerate(self.data_path_list):
            img_format = self.img_format_list[idx] if pair_formats else self.img_format_list[0]
            assert os.path.exists(data_dir), 'No dir: ' + data_dir
            img_path_list = glob.glob(data_dir + '/*{0}'.format(img_format))
            if len(img_path_list) == 0:
                # No images at the top level: treat each subdirectory as an
                # image folder (non-directory entries are ignored).
                for img_path in glob.glob(data_dir + '/*'):
                    if os.path.isdir(img_path):
                        data_info = get_data_infos(img_path, img_format)
                        data_information['img'].extend(data_info['img'])
                        data_information['label'].extend(data_info['label'])
            else:
                data_info = get_data_infos(data_dir, img_format)
                data_information['img'].extend(data_info['img'])
                data_information['label'].extend(data_info['label'])
        self._write_csv(data_information)

    def generate_val_data_from_train_data(self, frac=0.1, random_state=None):
        """Split the existing CSV at ``save_path_csv/csv_name`` into train and val.

        A random ``frac`` of the rows is written to the val CSV and the
        remaining rows overwrite the original CSV. The val file name is
        ``csv_name`` with ``'train'`` replaced by ``'val'``; if ``csv_name``
        contains no ``'train'``, ``_val`` is appended to its stem instead so
        the val split cannot overwrite the train split.

        Args:
            frac: Fraction of rows moved into the val split.
            random_state: Seed forwarded to ``DataFrame.sample`` for a
                reproducible split; ``None`` keeps a non-deterministic split.

        Raises:
            FileNotFoundError: If the CSV to split does not exist yet.
        """
        csv_path = os.path.join(self.save_path_csv, self.csv_name)
        if not os.path.exists(csv_path):
            raise FileNotFoundError('no train data')
        # keep_default_na=False keeps the literal 'None' label placeholders
        # written by get_csv as strings; recent pandas versions treat 'None'
        # as NA by default, which would write them back as empty cells.
        data = pd.read_csv(csv_path, keep_default_na=False)
        val_data = data.sample(frac=frac, replace=False, random_state=random_state)
        train_data = data.drop(val_data.index)
        val_data = val_data.reset_index(drop=True)
        train_data = train_data.reset_index(drop=True)

        val_name = self.csv_name.replace('train', 'val')
        if val_name == self.csv_name:
            root, ext = os.path.splitext(self.csv_name)
            val_name = root + '_val' + ext
        train_data.to_csv(csv_path, index_label=False)
        val_data.to_csv(os.path.join(self.save_path_csv, val_name), index_label=False)

    def _get_file(self, in_path_list):
        """Recursively expand directories in ``in_path_list`` into a flat list of file paths."""
        file_list = []
        for file in in_path_list:
            if os.path.isdir(os.path.abspath(file)):
                file_list.extend(self._get_file(glob.glob(file + '/*')))
            else:
                file_list.append(file)
        return file_list

    def get_csv_file(self, phase):
        """Collect image files under ``data_path_list`` and write a CSV for ``phase``.

        Files are kept when the text after their last ``'.'`` is in
        ``img_format_list``; the list is then shuffled in place.

        Args:
            phase: One of ``'seg'`` (``img``/``label`` columns), ``'flow'``
                (``img1``/``img2`` columns of consecutive shuffled entries) or
                ``'od'`` (``img``/``label`` with the label at the same path but
                a ``.txt`` extension).

        Raises:
            AssertionError: If ``phase`` is unknown or no matching files exist.
        """
        phases = ['seg', 'flow', 'od']
        assert phase in phases, f'{phase} should in {phases}!'
        file_list = self._get_file(self.data_path_list)
        file_list = [x for x in file_list if x.split('.')[-1] in self.img_format_list]
        assert len(file_list), 'No data in data_path_list!'
        random.shuffle(file_list)
        data_information = {}
        if phase == 'seg':
            data_information['img'] = file_list
            # NOTE(review): replaces every 'img' substring in the path (directory
            # and file name alike) -- confirm the label naming follows this.
            data_information['label'] = [x.replace('img', 'label') for x in file_list]
        elif phase == 'flow':
            # NOTE(review): pairs are neighbours in the *shuffled* list, so
            # img1/img2 are not necessarily adjacent frames -- confirm intent.
            data_information['img1'] = file_list[:-1]
            data_information['img2'] = file_list[1:]
        elif phase == 'od':
            data_information['img'] = file_list
            # Swap only the extension: str.replace on the whole path would also
            # rewrite directories such as '.../png_set/...', and left labels
            # identical to the image path for extensions other than tiff/jpg/png.
            data_information['label'] = [os.path.splitext(x)[0] + '.txt' for x in file_list]
        self._write_csv(data_information)
if __name__ == '__main__':
    # Index the validation sample images into generate_dep_info/val_data.csv
    # using the segmentation layout (img/label columns).
    csv_builder = GetTrainTestCSV(
        dataset_path_list=['D:/Code/ProjectOnGithub/STT/Data/val_samples/img'],
        csv_name='val_data.csv',
        img_format_list=['png'],
    )
    csv_builder.get_csv_file(phase='seg')