|
import os |
|
from findfile import find_files, find_dir |
|
|
|
# Path substrings that detect_infer_dataset passes as ``exclude_key`` to
# find_dir / find_files, so matching files and directories are never picked up
# as inference datasets. Kept as a list (not a tuple) because callers build
# ``["train."] + filter_key_words``.
filter_key_words = [
    ".py", ".md", "readme", "log", "result", "zip", ".state_dict",
    ".model", ".png", "acc_", "f1_", ".backup", ".bak",
]
|
|
|
|
|
def detect_infer_dataset(dataset_path, task="apc"):
    """Collect inference dataset files for ``task``.

    Args:
        dataset_path: One of:
            * a path to an existing file, returned as-is without filtering;
            * a single string naming a dataset keyword or a directory;
            * an iterable of such keywords / paths.
        task: Keyword (default ``"apc"``) used to narrow the search.

    Returns:
        list: Matched file paths, concatenated in the order ``find_files``
        returns them for each entry of ``dataset_path``.
    """
    # An existing file path is taken verbatim, with no keyword filtering.
    if isinstance(dataset_path, str) and os.path.isfile(dataset_path):
        return [dataset_path]

    # A bare string (keyword or directory) is a single entry. Without this,
    # the loop below would iterate over its individual characters.
    if isinstance(dataset_path, str):
        dataset_path = [dataset_path]

    dataset_file = []
    for d in dataset_path:
        if not os.path.exists(d):
            # ``d`` is not a real path: treat it as a dataset keyword and ask
            # find_dir for a directory under the CWD matching d / task /
            # "dataset", then collect ".inference" files there that match d.
            # NOTE(review): behaviour when find_dir finds nothing depends on
            # findfile (it is passed straight to find_files) — confirm.
            search_path = find_dir(
                os.getcwd(),
                [d, task, "dataset"],
                exclude_key=filter_key_words,
                disable_alert=False,
            )
            dataset_file += find_files(
                search_path,
                [".inference", d],
                exclude_key=["train."] + filter_key_words,
            )
        else:
            # ``d`` is an existing path: search it for ".inference" files
            # that also match the task keyword.
            dataset_file += find_files(
                d,
                [".inference", task],
                exclude_key=["train."] + filter_key_words,
            )

    return dataset_file
|
|