|
import os, re, numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.pyplot as plt_xsec
|
|
from datetime import datetime
|
|
|
|
# Log file parsed by output_logs().
input_log_file = './ewnet_logs_TRANS3_20240708.txt'

# True: fill every matching cross-section polygon on a per-station figure.
# False: draw one ratio bar chart per label instead.
flag_all_xsections = True

# Most recent station handled by output_logs().
prev_station = ''

# Timestamp embedded in the output file names of this run.
now = datetime.now()
now_str = now.strftime('%Y%m%d_%H%M')

# Parallel lists: color_list[i] is the RGB colour drawn for label_list[i].
label_list = [
    'pave_layer1', 'pave_layer2', 'pave_layer3', 'pave_layer4',
    'cut_ea', 'cut_rr', 'cut_br', 'cut_ditch',
    'fill_subbed', 'fill_subbody', 'curb', 'above', 'below',
    'pave_int', 'pave_surface', 'pave_subgrade', 'ground', 'pave_bottom',
    'rr', 'br', 'slope', 'struct', 'steps',
]

color_list = [
    [0.8, 0.8, 0.8],  # pave_layer1
    [0.6, 0.6, 0.6],  # pave_layer2
    [0.4, 0.4, 0.4],  # pave_layer3
    [0.2, 0.2, 0.2],  # pave_layer4
    [0.8, 0.4, 0.2],  # cut_ea
    [0.8, 0.6, 0.2],  # cut_rr
    [0.8, 0.8, 0.2],  # cut_br
    [0.6, 0.8, 0.2],  # cut_ditch
    [0.3, 0.8, 0.3],  # fill_subbed
    [0.3, 0.6, 0.3],  # fill_subbody
    [0.3, 0.4, 0.3],  # curb
    [0.0, 0.8, 0.0],  # above
    [0.6, 0.0, 0.0],  # below
    [0.8, 0.0, 0.0],  # pave_int
    [1.0, 0.0, 0.0],  # pave_surface
    [0.2, 0.2, 0.6],  # pave_subgrade
    [0.0, 1.0, 0.0],  # ground
    [0.2, 0.2, 1.0],  # pave_bottom
    [0.4, 0.2, 1.0],  # rr
    [0.6, 0.2, 1.0],  # br
    [0.2, 0.8, 0.6],  # slope
    [0.8, 0.2, 1.0],  # struct
    [1.0, 0.2, 1.0],  # steps
]

# Output directory for every generated image and the CSV summary.
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs('./graph', exist_ok=True)
|
|
|
|
def draw_colorbox_list():
    """Save a colour-swatch chart of every label to ./graph/box_colors.png.

    Each entry of ``label_list`` is drawn as one horizontal segment on six
    dummy rows (``item1`` .. ``item6``), filled with the matching entry of
    ``color_list``. A legend names the labels.
    """
    fig, ax = plt.subplots(figsize=(9.2, 5))
    ax.invert_yaxis()
    # The bars span [0, 1]; the extra room on the right is left for the legend.
    ax.set_xlim(0, 1.5)
    # Overrides the figsize above; kept so the saved image size is unchanged.
    fig.set_size_inches(12, 7)

    token_list = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6']
    # Every label gets an equal share of the unit width (guarded for an empty list).
    width = 1.0 / max(len(label_list), 1)
    widths = [width] * len(token_list)

    for i, (colname, color) in enumerate(zip(label_list, color_list)):
        ax.barh(token_list, widths, left=width * i, height=0.5, label=colname, color=color)

    ax.legend()
    fig.savefig('./graph/box_colors.png')
    plt.close(fig)
|
|
|
|
def output_graph_matrics(index, tag, text):
    """Plot one classified polygon taken from a log line and return its summary.

    The line is parsed for:

    * ``Predicted: <label>, ...`` - the first comma-separated item is the
      predicted label;
    * ``Geom: ...`` - ``label(ratio)`` tokens separated by commas or spaces;
    * optionally ``Polyline: [...]`` - the polygon vertices as a Python literal.

    When ``flag_all_xsections`` is False, a stacked ratio bar chart is saved
    to ``./graph/box_list_<now_str>_<tag>_<index>_<T|F>.png``. Otherwise the
    polygon is filled on the current ``plt_xsec`` figure. Unless the label
    contains ``'pave'``, the polygon is also annotated with its area.

    Args:
        index: Position of the line; used in the bar-chart file name.
        tag: Model tag; spaces and colons are sanitised for the bar-chart file name.
        text: One log line.

    Returns:
        A tuple ``(prediction, area, token_list)``:

        * ``prediction`` is the predicted label.
        * ``area`` is the shoelace area of the polyline, or 0.0 when there is
          none.
        * ``token_list`` holds the parsed token labels with the prediction
          first.

        Lines without ``Geom:``, without ``Predicted: `` or with an empty
        prediction yield ``('', 0.0, [])``, so callers can always unpack the
        result.
    """
    geom_index = text.find('Geom:')
    label_index = text.find('Predicted: ')
    if geom_index < 0 or label_index < 0:
        return '', 0.0, []

    prediction = text[label_index + 11:geom_index].split(', ')[0]
    if not prediction:
        return '', 0.0, []

    polyline_index = text.find('Polyline:')
    if polyline_index > 0:
        # Skip "Geom: " and stop before the ", " that precedes "Polyline:".
        geom_text = text[geom_index + 6:polyline_index - 2]
        # NOTE(review): eval() executes arbitrary code embedded in the log
        # file. If polylines are always plain literals, ast.literal_eval
        # would be the safe choice.
        polyline = list(eval(text[polyline_index + 10:]))
    else:
        geom_text = text[geom_index + 6:]
        polyline = []

    token_list, ratios = _parse_geom_tokens(prediction, geom_text)
    area = _polygon_area(polyline)

    if not flag_all_xsections:
        _save_ratio_bars(index, tag, text, token_list, ratios)
    elif polyline:
        _fill_polygon(prediction, polyline, area)

    return prediction, area, token_list


def _parse_geom_tokens(prediction, geom_text):
    """Split ``label(ratio)`` tokens out of *geom_text*, with the prediction first.

    The prediction is prepended with a fixed ratio of 0.3, so it becomes the
    first bar of the ratio chart. Returns ``(labels, ratios)``.
    """
    cleaned = geom_text.translate(str.maketrans('', '', "[])'"))
    tokens = cleaned.split(',')
    if len(tokens) <= 1:
        tokens = cleaned.split(' ')
    # Drop empty / whitespace-only pieces left by trailing or doubled separators.
    tokens = [token for token in tokens if token.strip()]
    tokens.insert(0, f'{prediction}(0.3')

    labels = [token.split('(')[0].replace(' ', '') for token in tokens]
    ratios = [float(token.split('(')[1]) for token in tokens]
    return labels, ratios


def _polygon_area(points):
    """Return the absolute shoelace area of the polygon through *points*.

    The ring is closed implicitly. Fewer than three vertices give 0.0.
    """
    if len(points) < 3:
        return 0.0
    ring = list(points)
    if ring[0] != ring[-1]:
        ring.append(ring[0])
    twice_area = sum(x0 * y1 - x1 * y0 for (x0, y0), (x1, y1) in zip(ring, ring[1:]))
    return 0.5 * abs(twice_area)


def _save_ratio_bars(index, tag, text, token_list, ratios):
    """Save a one-row stacked bar chart of *ratios* for a single log line."""
    data = np.array([ratios])
    data_cum = data.cumsum(axis=1)
    token_colors = [color_list[label_list.index(label)] for label in token_list]
    row_labels = [token_list[0]]

    fig, ax = plt.subplots(figsize=(9.2, 5))
    ax.invert_yaxis()
    ax.xaxis.set_visible(False)
    ax.set_xlim(0, np.sum(data, axis=1).max())
    fig.set_size_inches(15, 0.5)

    for i, (colname, color) in enumerate(zip(token_list, token_colors)):
        widths = data[:, i]
        starts = data_cum[:, i] - widths
        if i > 0:
            # Shift every geometry bar right, leaving a gap after the prediction bar.
            starts += 0.02
        rects = ax.barh(row_labels, widths, left=starts, height=0.5, label=colname, color=color)
        if i != 0:
            text_color = 'white' if np.max(color) < 0.4 else 'black'
            ax.bar_label(rects, label_type='center', color=text_color)
    ax.legend(ncols=len(token_list), bbox_to_anchor=(0, 1), loc='lower right', fontsize='small')

    safe_tag = tag.replace(' ', '_').replace(':', '')
    suffix = 'T' if text.find('True') > 0 else 'F'
    plt.savefig(f'./graph/box_list_{now_str}_{safe_tag}_{index}_{suffix}.png')
    plt.close()


def _fill_polygon(prediction, polyline, area):
    """Fill *polyline* on the current plt_xsec figure in the colour of *prediction*."""
    ring = list(polyline)
    if ring[0] != ring[-1]:
        ring.append(ring[0])
    x, y = zip(*ring)

    plt_xsec.fill(x, y, color=color_list[label_list.index(prediction)])

    # Vertex average of the closed ring; only used to place the area label.
    centroid_x = sum(x) / len(x)
    centroid_y = sum(y) / len(y)
    if 'pave' not in prediction:
        plt_xsec.text(centroid_x, centroid_y, f'{prediction}={area:.2f}', horizontalalignment='center', verticalalignment='center', fontsize=5, color='black')
|
|
|
|
# Stations rendered by output_logs(); an empty list disables the filter.
output_stations = [
    '4+440.00000',
    '3+780.00000',
    '3+800.00000',
    '3+880.00000',
    '3+940.00000',
]
|
|
def output_logs(tag, equal='none'):
    """Render the per-station cross-section figures for one model tag.

    The selected log lines (see ``_select_log_lines``) are sorted by station,
    and each one is passed to ``output_graph_matrics``. Each station gets its
    own ``plt_xsec`` figure. That figure is saved as
    ``./graph/polygon_<now_str>_<tag>_<station>_<T|F>.png`` when the next
    station starts, or after the loop ends. When ``output_stations`` is
    non-empty, stations not listed there are skipped.

    Args:
        tag: Model tag, e.g. ``'MLP [128, 64, 32]'``.
        equal: Substring a line must contain to be used (e.g. ``'Equal: True'``).
            ``'none'`` disables this filter.

    Returns:
        One dict per rendered line, with keys ``index``, ``station``,
        ``label``, ``area`` and ``tokens``.
    """
    global prev_station

    logs = []
    text_list = _select_log_lines(tag, equal)
    if not text_list:
        return logs

    # Start a fresh figure sequence. A station left over from a previous call
    # must not trigger a save of a figure that call already closed.
    prev_station = ''
    figure_open = False
    figure_check = 'F'
    station = ''

    for index, text in enumerate(sorted(text_list, key=_station_sort_key)):
        sta_index = text.find('Station:')
        equal_index = text.find('Equal: ')
        equal_check = 'T' if text.find('True') > 0 else 'F'

        if sta_index > 0 and equal_index > 0:
            # Text between "Station: " and the ", " that precedes "Equal: ".
            station = text[sta_index + 9:equal_index - 2]
            print(station)

            if output_stations and station not in output_stations:
                continue

            if station != prev_station:
                if figure_open:
                    _save_station_figure(tag, prev_station, figure_check)
                plt_xsec.figure()
                axes = plt_xsec.gca()
                axes.set_xlim([-60, 60])
                axes.axis('equal')
                axes.text(0, 0, f'{station}', fontsize=12, color='black')
                figure_open = True
                figure_check = equal_check
                prev_station = station

        result = output_graph_matrics(index, tag, text)
        if result is None:
            # Line without a 'Geom:' section: nothing was drawn or measured.
            continue
        label, area, tokens = result
        figure_check = equal_check
        logs.append({
            'index': index,
            'station': station,
            'label': label,
            'area': area,
            'tokens': tokens,
        })

    # Save the last station's figure even when the final sorted lines were
    # filtered out by output_stations.
    if figure_open:
        _save_station_figure(tag, prev_station, figure_check)
    return logs


def _select_log_lines(tag, equal):
    """Return the newline-stripped lines of ``input_log_file`` to render for *tag*.

    With ``flag_all_xsections`` set, every line containing the model family
    (the first word of *tag*) is kept. Otherwise, for each entry of
    ``label_list`` in order, the first line that contains the full *tag*
    and ``'Label: <label>'`` is kept. In both modes, a line must also
    contain *equal* unless *equal* is ``'none'``.
    """
    def wanted(line):
        return equal == 'none' or line.find(equal) > 0

    with open(input_log_file, 'r') as file:
        lines = [line.replace('\n', '') for line in file]

    if flag_all_xsections:
        tag_model = tag.split(' ')[0]
        return [line for line in lines if tag_model in line and wanted(line)]

    selected = []
    for label in label_list:
        key = 'Label: ' + label
        match = next((line for line in lines if tag in line and key in line and wanted(line)), None)
        if match is not None:
            selected.append(match)
    return selected


def _station_sort_key(text):
    """Return the text after "Station: " up to the next comma (or line end)."""
    sta_index = text.find('Station:') + 9
    end_index = text.find(',', sta_index)
    return text[sta_index:end_index] if end_index != -1 else text[sta_index:]


def _save_station_figure(tag, station, equal_check):
    """Save the current plt_xsec figure for *station* at 300 dpi, then close it."""
    plt_xsec.savefig(f'./graph/polygon_{now_str}_{tag}_{station}_{equal_check}.png', dpi=300)
    plt_xsec.close()
|
|
|
|
def main():
    """Draw the colour legend, render every model's polygons and log area sums.

    For each model tag:

    * When ``output_stations`` is non-empty, only the station figures are
      rendered and no summary row is written.
    * Otherwise the 'Equal: True' and 'Equal: False' lines are rendered
      separately. Their total polygon areas and line counts are appended to
      ./graph/summary_log.csv.

    With ``flag_all_xsections`` set, processing stops after the first tag
    that produced a summary row.
    """
    draw_colorbox_list()

    summary_path = './graph/summary_log.csv'
    # The CSV is appended to across runs; write the header only once.
    write_header = not os.path.exists(summary_path) or os.path.getsize(summary_path) == 0

    tags = ['MLP [128, 64, 32]', 'MLP [64, 128, 64]', 'MLP [64, 128, 64, 32]', 'LSTM [128]', 'LSTM [128, 64, 32]', 'LSTM [256, 128, 64]', 'transformer 32', 'transformer 64', 'transformer 128', 'BERT']

    with open(summary_path, 'a') as summary_log_file:
        if write_header:
            summary_log_file.write('model, ground true, length, ground false, length\n')

        for tag in tags:
            print(tag)
            if output_stations:
                # Station filter active: render the figures only, no summary row.
                output_logs(tag)
                continue

            logs_true = output_logs(tag, 'Equal: True')
            logs_false = output_logs(tag, 'Equal: False')
            if not logs_true or not logs_false:
                continue

            area_true = sum(log['area'] for log in logs_true)
            area_false = sum(log['area'] for log in logs_false)
            summary_log_file.write(
                f'{tag}, {area_true}, {len(logs_true)}, {area_false}, {len(logs_false)}\n'
            )

            if flag_all_xsections:
                break
|
|
|
|
# Run the full report only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|