|
|
|
|
|
import inspect
import os
import random
import subprocess
import sys

import matplotlib.cm as cmx
import matplotlib.colors as colors
import matplotlib.legend as lgd
import matplotlib.markers as mks
import matplotlib.pyplot as plt
|
|
|
|
|
def get_log_parsing_script():
    """Return the path of ``parse_log.sh`` in the directory of this file.

    The directory is taken from the absolute path of the source file that
    owns the currently executing frame, and '/parse_log.sh' is appended.
    """
    this_file = inspect.getfile(inspect.currentframe())
    script_dir = os.path.dirname(os.path.abspath(this_file))
    return script_dir + '/parse_log.sh'
|
|
|
|
|
def get_log_file_suffix():
    """Return the file-name suffix that a log file is required to have."""
    suffix = '.log'
    return suffix
|
|
|
|
|
def get_chart_type_description_separator():
    """Return the text placed between the y-axis and x-axis field names."""
    separator = ' vs. '
    return separator
|
|
|
|
|
def is_x_axis_field(field):
    """Return True when ``field`` is 'Iters' or 'Seconds', else False."""
    return field in ('Iters', 'Seconds')
|
|
|
|
|
def create_field_index():
    """Build the field-name to column-index lookup for the data files.

    Returns:
        A ``(field_index, fields)`` tuple. ``field_index`` maps each data
        file type ('Train' or 'Test') to a dict from field name to column
        index. ``fields`` is the sorted list of every distinct field name
        found across both data file types.
    """
    train, test = 'Train', 'Test'
    train_columns = {
        'Iters': 0,
        'Seconds': 1,
        train + ' loss': 2,
        train + ' learning rate': 3,
    }
    test_columns = {
        'Iters': 0,
        'Seconds': 1,
        test + ' accuracy': 2,
        test + ' loss': 3,
    }
    field_index = {train: train_columns, test: test_columns}
    # Collect the distinct names of both tables, then order them.
    distinct_names = {name
                      for columns in field_index.values()
                      for name in columns}
    return field_index, sorted(distinct_names)
|
|
|
|
|
def get_supported_chart_types():
    """Return the descriptions of every supported chart type.

    Each description is '<y field><separator><x field>', pairing every
    non-x-axis field with every x-axis field. The y fields vary slowest
    and both follow the sorted order returned by create_field_index(), so
    a description's position in the list is its chart type number.

    Returns:
        A list of description strings.
    """
    _, fields = create_field_index()
    separator = get_chart_type_description_separator()
    x_fields = [name for name in fields if is_x_axis_field(name)]
    y_fields = [name for name in fields if not is_x_axis_field(name)]
    # Direct iteration instead of xrange-based indexing keeps this working
    # on both Python 2 and Python 3.
    return ['%s%s%s' % (y_field, separator, x_field)
            for y_field in y_fields
            for x_field in x_fields]
|
|
|
|
|
def get_chart_type_description(chart_type):
    """Return the description of the chart type numbered ``chart_type``.

    ``chart_type`` is used as an index into the list returned by
    get_supported_chart_types(), so normal list indexing rules apply
    (IndexError when it is out of range).
    """
    return get_supported_chart_types()[chart_type]
|
|
|
|
|
def get_data_file_type(chart_type):
    """Return the first whitespace-separated word of the chart description.

    With the descriptions built in this file that word is the data file
    type, e.g. 'Train' or 'Test'.
    """
    words = get_chart_type_description(chart_type).split()
    return words[0]
|
|
|
|
|
def get_data_file(chart_type, path_to_log):
    """Return the data file name for a log and chart type.

    The name is the base name of ``path_to_log`` followed by '.' and the
    lower-cased data file type of ``chart_type``; it has no directory part.
    """
    log_name = os.path.basename(path_to_log)
    extension = get_data_file_type(chart_type).lower()
    return log_name + '.' + extension
|
|
|
|
|
def get_field_descriptions(chart_type):
    """Split a chart description into its axis field names.

    Returns:
        An ``(x_axis_field, y_axis_field)`` tuple; the description lists
        the y field first, so the order is swapped on return.
    """
    separator = get_chart_type_description_separator()
    parts = get_chart_type_description(chart_type).split(separator)
    y_axis_field, x_axis_field = parts[0], parts[1]
    return x_axis_field, y_axis_field
|
|
|
|
|
def get_field_indices(x_axis_field, y_axis_field):
    """Return the data-file column indices of the two axis fields.

    Args:
        x_axis_field: Name of the x-axis field, e.g. 'Iters'.
        y_axis_field: Name of the y-axis field, e.g. 'Train loss'.

    Returns:
        An ``(x_index, y_index)`` tuple of column indices.

    Raises:
        KeyError: If the data file type or a field name is not present in
            the index built by create_field_index().
    """
    # The data file type is the first word of the y-axis field name
    # ('Train loss' -> 'Train'). Deriving it from the argument keeps the
    # function self-contained instead of depending on a module-level
    # ``chart_type`` variable that exists only when run as a script.
    data_file_type = y_axis_field.split()[0]
    columns = create_field_index()[0][data_file_type]
    return columns[x_axis_field], columns[y_axis_field]
|
|
|
|
|
def load_data(data_file, field_idx0, field_idx1):
    """Read two numeric columns from a whitespace-separated data file.

    Blank lines and lines whose first non-blank character is '#' are
    skipped.

    Args:
        data_file: Path of the file to read.
        field_idx0: Column index of the values stored in the first list.
        field_idx1: Column index of the values stored in the second list.

    Returns:
        ``[values0, values1]``: two lists of floats, one entry per data
        line, in file order.

    Raises:
        IndexError: If a data line has too few columns.
        ValueError: If a selected column is not a valid float.
    """
    data = [[], []]
    with open(data_file, 'r') as f:
        for line in f:
            line = line.strip()
            # An empty line has no first character to test, so it must be
            # rejected before the comment check.
            if not line or line.startswith('#'):
                continue
            fields = line.split()
            data[0].append(float(fields[field_idx0]))
            data[1].append(float(fields[field_idx1]))
    return data
|
|
|
|
|
def random_marker():
    """Return a randomly chosen key of ``MarkerStyle.markers``."""
    # Materialize the keys first: on Python 3 ``dict.keys()`` is a view
    # that cannot be indexed.
    marker_keys = list(mks.MarkerStyle.markers.keys())
    return random.choice(marker_keys)
|
|
|
|
|
def get_data_label(path_to_log):
    """Return the part of ``path_to_log`` used as the plot label.

    The label is the text after the last '/' and before the last
    occurrence of the log file suffix.
    """
    start = path_to_log.rfind('/') + 1
    end = path_to_log.rfind(get_log_file_suffix())
    return path_to_log[start:end]
|
|
|
|
|
def get_legend_loc(chart_type):
    """Return the legend location string for a chart type.

    Charts whose y-axis field contains 'loss' or 'learning rate' get
    'upper right'; every other chart (including accuracy) gets
    'lower right'.
    """
    _, y_field = get_field_descriptions(chart_type)
    if 'loss' in y_field or 'learning rate' in y_field:
        return 'upper right'
    return 'lower right'
|
|
|
|
|
def plot_chart(chart_type, path_to_png, path_to_log_list):
    """Parse each log, plot the chosen chart, save it and show it.

    For every log the parsing script is run on it, the two columns
    selected by ``chart_type`` are loaded from the data file named by
    get_data_file() (a bare file name, so it is opened relative to the
    current working directory), and one line is drawn with a random
    color and a random marker.

    Args:
        chart_type: Index into the supported chart types.
        path_to_png: Where the finished figure is saved.
        path_to_log_list: Paths of the log files to plot.
    """
    # These depend only on the chart type, so compute them once. Doing so
    # before the loop also keeps the axis labels defined when
    # path_to_log_list is empty.
    x_axis_field, y_axis_field = get_field_descriptions(chart_type)
    x, y = get_field_indices(x_axis_field, y_axis_field)
    parsing_script = get_log_parsing_script()
    linewidth = 0.75
    # Customization point: set to False to draw lines without markers.
    use_marker = True

    for path_to_log in path_to_log_list:
        # Pass an argument list rather than a shell string so that paths
        # containing spaces or shell metacharacters reach the script intact.
        subprocess.call([parsing_script, path_to_log])
        data_file = get_data_file(chart_type, path_to_log)
        data = load_data(data_file, x, y)

        plot_options = {
            'label': get_data_label(path_to_log),
            'color': [random.random(), random.random(), random.random()],
            'linewidth': linewidth,
        }
        if use_marker:
            plot_options['marker'] = random_marker()
        plt.plot(data[0], data[1], **plot_options)

    plt.legend(loc=get_legend_loc(chart_type), ncol=1)
    plt.title(get_chart_type_description(chart_type))
    plt.xlabel(x_axis_field)
    plt.ylabel(y_axis_field)
    plt.savefig(path_to_png)
    plt.show()
|
|
|
|
|
def print_help():
    """Print the usage text and the supported chart types, then exit.

    Always ends the process through ``sys.exit()``, so callers never
    regain control.
    """
    supported_chart_types = get_supported_chart_types()
    # print(...) with a single argument behaves the same under the
    # Python 2 print statement and the Python 3 print function.
    print("""This script mainly serves as the basis of your customizations.
Customization is a must.
You can copy, paste, edit them in whatever way you want.
Be warned that the fields in the training log may change in the future.
You had better check the data files and change the mapping from field name to
field index in create_field_index before designing your own plots.
Usage:
./plot_training_log.py chart_type[0-%s] /where/to/save.png /path/to/first.log ...
Notes:
1. Supporting multiple logs.
2. Log file name must end with the lower-cased "%s".
Supported chart types:""" % (len(supported_chart_types) - 1,
                             get_log_file_suffix()))
    for number, description in enumerate(supported_chart_types):
        print(' %d: %s' % (number, description))
    sys.exit()
|
|
|
|
|
def is_valid_chart_type(chart_type):
    """Return True when ``chart_type`` is a valid non-negative index into
    the supported chart types."""
    return 0 <= chart_type < len(get_supported_chart_types())
|
|
|
|
|
# Command-line entry point:
#   plot_training_log.py chart_type /where/to/save.png /path/to/first.log ...
# Arguments are validated first; print_help() and sys.exit() end the process
# on the first problem found, so plot_chart() only runs with checked input.
if __name__ == '__main__':
    if len(sys.argv) < 4:
        print_help()
    else:
        try:
            chart_type = int(sys.argv[1])
        except ValueError:
            # A non-numeric chart type gets the same message as an
            # out-of-range one instead of a raw traceback.
            print('%s is not a valid chart type.' % sys.argv[1])
            print_help()
        if not is_valid_chart_type(chart_type):
            print('%s is not a valid chart type.' % chart_type)
            print_help()
        path_to_png = sys.argv[2]
        if not path_to_png.endswith('.png'):
            # The message needs a %s placeholder: %-formatting a string
            # that has none with an argument raises TypeError.
            print('Path must end with .png: %s' % path_to_png)
            sys.exit()
        path_to_logs = sys.argv[3:]
        for path_to_log in path_to_logs:
            if not os.path.exists(path_to_log):
                print('Path does not exist: %s' % path_to_log)
                sys.exit()
            if not path_to_log.endswith(get_log_file_suffix()):
                print('Log file must end in %s.' % get_log_file_suffix())
                print_help()
        plot_chart(chart_type, path_to_png, path_to_logs)
|
|
|