Spaces:
Runtime error
Runtime error
import os | |
import numpy as np | |
def parse_model_cfg(path):
    """Parse a YOLO ``*.cfg`` file and return its module definitions.

    ``path`` may be ``'cfg/yolov3.cfg'``, ``'yolov3.cfg'`` or ``'yolov3'``: a
    missing ``.cfg`` suffix is appended, and when the file does not exist as
    given but does exist under ``cfg/``, that location is used instead.

    Returns:
        A list with one dict per ``[section]`` header, in file order. Each
        dict has a ``'type'`` entry (the section name) plus one entry per
        ``key=value`` line of that section:

        * ``anchors``                      -> ``np.ndarray`` reshaped to (-1, 2)
        * ``from`` / ``layers`` / ``mask`` -> list of int
        * ``size`` containing a comma      -> list of int
        * values made only of digits       -> int
        * anything else (including floats
          and negative numbers)            -> stripped string

        ``convolutional`` sections get ``batch_normalize = 0`` unless the
        file overrides it.

    Raises:
        AssertionError: a section other than the first uses a field that is
            not in the supported list.
        ValueError: a non-header line has no ``=``, or a list-valued field
            contains a non-numeric item.
        IndexError: a ``key=value`` line appears before any ``[section]``.
    """
    if not path.endswith('.cfg'):  # add .cfg suffix if omitted
        path += '.cfg'
    if not os.path.exists(path) and os.path.exists('cfg' + os.sep + path):  # add cfg/ prefix if omitted
        path = 'cfg' + os.sep + path

    with open(path, 'r') as f:
        # Strip BEFORE filtering so whitespace-only lines and indented
        # '#' comments are dropped instead of reaching the '=' split below.
        lines = [x.strip() for x in f.read().split('\n')]
    lines = [x for x in lines if x and not x.startswith('#')]

    mdefs = []  # module definitions
    for line in lines:
        if line.startswith('['):  # This marks the start of a new block
            mdef = {'type': line[1:-1].rstrip()}
            if mdef['type'] == 'convolutional':
                mdef['batch_normalize'] = 0  # pre-populate with zeros (may be overwritten later)
            mdefs.append(mdef)
            continue

        # Split on the first '=' only, so values may themselves contain '='.
        key, val = line.split('=', 1)
        key = key.strip()
        mdef = mdefs[-1]

        if key == 'anchors':  # return nparray
            mdef[key] = np.array([float(x) for x in val.split(',')]).reshape((-1, 2))  # np anchors
        elif key in ('from', 'layers', 'mask') or (key == 'size' and ',' in val):  # return list of int
            mdef[key] = [int(x) for x in val.split(',')]
        else:
            val = val.strip()
            # Only plain digit strings become ints; floats such as '.7' and
            # negative values are intentionally left as strings.
            mdef[key] = int(val) if val.isdecimal() else val

    # Check all fields are supported
    supported = {'type', 'batch_normalize', 'filters', 'size', 'stride', 'pad', 'activation', 'layers', 'groups',
                 'from', 'mask', 'anchors', 'classes', 'num', 'jitter', 'ignore_thresh', 'truth_thresh', 'random',
                 'stride_x', 'stride_y', 'weights_type', 'weights_normalization', 'scale_x_y', 'beta_nms', 'nms_kind',
                 'iou_loss', 'iou_normalizer', 'cls_normalizer', 'iou_thresh', 'atoms', 'na', 'nc'}

    # The first block is skipped: its fields are not validated against the list.
    u = []  # unsupported fields, in order of first appearance
    seen = set()
    for mdef in mdefs[1:]:
        for k in mdef:
            if k not in seen:
                seen.add(k)
                if k not in supported:
                    u.append(k)

    if u:
        # Raised explicitly (rather than via `assert`) so the check still runs
        # under `python -O`; the exception type is unchanged for callers.
        raise AssertionError(
            "Unsupported fields %s in %s. See https://github.com/ultralytics/yolov3/issues/631" % (u, path))

    return mdefs
def parse_data_cfg(path):
    """Parse a ``key = value`` data configuration file into a dict.

    If ``path`` does not exist as given but does exist under ``data/``, that
    location is used instead. Blank lines and lines starting with ``#`` are
    ignored. Keys and values are whitespace-stripped and returned as strings;
    no type conversion is performed.

    Returns:
        dict mapping each option name to its string value. When a key is
        repeated, the last occurrence wins.

    Raises:
        ValueError: a non-blank, non-comment line contains no ``=``.
    """
    if not os.path.exists(path) and os.path.exists('data' + os.sep + path):  # add data/ prefix if omitted
        path = 'data' + os.sep + path

    options = {}
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            # Split on the first '=' only so values such as URLs or paths
            # that contain '=' are kept intact instead of raising.
            key, val = line.split('=', 1)
            options[key.strip()] = val.strip()

    return options