# earthwork-net-model / extract_ewlog.py
# mac999's picture
# Upload 7 files
# af359c9 verified
# raw
# history blame
# 8.13 kB
import os, re, numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt_xsec
from datetime import datetime
# Path of the EW-Net prediction log that this script parses.
input_log_file = './ewnet_logs_TRANS3_20240708.txt'
# True: output_logs() reads every line of a model family and
# output_graph_matrics() fills cross-section polygons;
# False: one stacked-bar chart per selected log line.
flag_all_xsections = True
# Station whose cross-section figure is currently open (updated by output_logs).
prev_station = ''
now = datetime.now()
# Timestamp embedded in every image file name written during this run.
now_str = now.strftime('%Y%m%d_%H%M')
# Parallel lists: label_list[i] is drawn with the RGB colour color_list[i].
label_list = [
    'pave_layer1', 'pave_layer2', 'pave_layer3', 'pave_layer4',
    'cut_ea', 'cut_rr', 'cut_br', 'cut_ditch',
    'fill_subbed', 'fill_subbody', 'curb',
    'above', 'below', 'pave_int', 'pave_surface',
    'pave_subgrade', 'ground', 'pave_bottom',
    'rr', 'br', 'slope', 'struct', 'steps',
]
color_list = [
    [0.8, 0.8, 0.8], [0.6, 0.6, 0.6], [0.4, 0.4, 0.4], [0.2, 0.2, 0.2],
    [0.8, 0.4, 0.2], [0.8, 0.6, 0.2], [0.8, 0.8, 0.2], [0.6, 0.8, 0.2],
    [0.3, 0.8, 0.3], [0.3, 0.6, 0.3], [0.3, 0.4, 0.3],
    [0.0, 0.8, 0.0], [0.6, 0.0, 0.0], [0.8, 0.0, 0.0], [1.0, 0.0, 0.0],
    [0.2, 0.2, 0.6], [0.0, 1.0, 0.0], [0.2, 0.2, 1.0],
    [0.4, 0.2, 1.0], [0.6, 0.2, 1.0], [0.2, 0.8, 0.6], [0.8, 0.2, 1.0], [1.0, 0.2, 1.0],
]
# Colours are looked up by label index, so the lists must stay aligned.
if len(label_list) != len(color_list):
    raise ValueError('label_list and color_list must have the same length')
# Output folder for all images and the summary CSV; exist_ok avoids the
# check-then-create race of os.path.exists() + os.makedirs().
os.makedirs('./graph', exist_ok=True)
def draw_colorbox_list():
    """Save a legend image showing the colour assigned to every label.

    One horizontal band per entry of ``label_list`` (coloured with the
    matching ``color_list`` entry) is drawn across six dummy rows, and the
    figure is written to ``./graph/box_colors.png``.
    """
    fig, ax = plt.subplots(figsize=(12, 7))
    ax.invert_yaxis()
    # Bands occupy [0, 1]; presumably the extra room up to 1.5 is for the legend.
    ax.set_xlim(0, 1.5)

    token_list = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6']
    # Every label gets an equal slice of the unit interval.
    width = 1.0 / len(label_list)
    widths = [width] * len(token_list)
    for i, (colname, color) in enumerate(zip(label_list, color_list)):
        ax.barh(token_list, widths, left=width * i, height=0.5, label=colname, color=color)

    ax.legend()
    fig.savefig('./graph/box_colors.png')
    plt.close(fig)
def _parse_geom_text(text):
    """Split one log line into ``(prediction, tokens, polyline)``.

    Returns ``None`` when the line has no ``'Geom:'`` field. ``tokens`` are
    raw ``'name(ratio'`` strings; the first one is the predicted label with
    a fixed ratio of 0.3. ``polyline`` is the evaluated ``'Polyline:'``
    value, or an empty list when that field is absent.
    """
    geom_index = text.find('Geom:')
    if geom_index < 0:
        return None

    prediction = ''
    pred_label = ''
    label_index = text.find('Predicted: ')
    if label_index >= 0:
        # split() always yields at least one element, so labels[0] exists.
        labels = text[label_index + 11:geom_index].split(', ')
        prediction = labels[0]
        # The predicted label is drawn as the first bar with a fixed ratio of 0.3.
        pred_label = labels[0] + '(0.3'

    polyline = []
    polyline_index = text.find('Polyline:')
    if polyline_index > 0:
        # The geometry text stops two characters before 'Polyline:'.
        pred = text[geom_index + 6:polyline_index - 2]
        # NOTE(review): eval() executes whatever the log file contains; only
        # run this on trusted logs. ast.literal_eval would be safer if the
        # polyline is always a plain list of numeric pairs -- confirm format.
        polyline = eval(text[polyline_index + 10:])
    else:
        pred = text[geom_index + 6:]

    # Strip list/tuple punctuation so only 'name(ratio' fragments remain.
    pred = pred.replace('[', '').replace(']', '').replace(')', '').replace("'", '')
    tokens = pred.split(',')
    if len(tokens) <= 1:
        tokens = pred.split(' ')
    tokens.insert(0, pred_label)
    if len(tokens[-1]) == 0:
        tokens.pop()
    return prediction, tokens, polyline


def _closed_ring(points):
    # Return the points as a new list whose last vertex repeats the first.
    ring = list(points)
    if ring and ring[0] != ring[-1]:
        ring.append(ring[0])
    return ring


def _polygon_area(points):
    # Shoelace-formula area of the polygon; 0.0 for fewer than three vertices.
    if len(points) < 3:
        return 0.0
    ring = _closed_ring(points)
    return 0.5 * abs(sum(x0 * y1 - x1 * y0 for (x0, y0), (x1, y1) in zip(ring, ring[1:])))


def _draw_token_bars(index, tag, text, token_list, ratios):
    # Save a one-row stacked-bar chart of the token ratios of one log line.
    data = np.array([ratios])
    data_cum = data.cumsum(axis=1)
    token_colors = [color_list[label_list.index(label)] for label in token_list]
    row_labels = [token_list[0]]

    fig, ax = plt.subplots(figsize=(15, 0.5))
    ax.invert_yaxis()
    ax.xaxis.set_visible(False)
    ax.set_xlim(0, np.sum(data, axis=1).max())
    for i, (colname, color) in enumerate(zip(token_list, token_colors)):
        widths = data[:, i]
        starts = data_cum[:, i] - widths
        if i > 0:
            # Small gap between the predicted bar and the geometry bars.
            starts += 0.02
        rects = ax.barh(row_labels, widths, left=starts, height=0.5, label=colname, color=color)
        if i != 0:
            text_color = 'white' if np.max(color) < 0.4 else 'black'
            ax.bar_label(rects, label_type='center', color=text_color)
    ax.legend(ncols=len(token_list), bbox_to_anchor=(0, 1), loc='lower right', fontsize='small')

    safe_tag = tag.replace(' ', '_').replace(':', '')
    suffix = 'T' if text.find('True') > 0 else 'F'
    fig.savefig(f'./graph/box_list_{now_str}_{safe_tag}_{index}_{suffix}.png')
    plt.close(fig)


def output_graph_matrics(index, tag, text):
    """Plot one log line and return its prediction summary.

    When ``flag_all_xsections`` is False, a stacked-bar chart of the token
    ratios is saved to ``./graph/box_list_<now>_<tag>_<index>_<T|F>.png``.
    Otherwise the line's polyline is filled into the current ``plt_xsec``
    figure, using the colour of the predicted label. For non-'pave' labels,
    the area is also written as text at the vertex average.

    Args:
        index: position of the line; used only in the bar-chart file name.
        tag: model tag; spaces become '_' and ':' is removed for file names.
        text: one log line containing 'Predicted: ', 'Geom:' and optionally
            'Polyline:' fields.

    Returns:
        ``(prediction, area, token_list)``: the predicted label, the polygon
        area (0.0 when there is no polyline) and the geometry token names.
        A line without 'Geom:' yields ``('', 0.0, [])`` rather than ``None``,
        so callers can always unpack three values.
    """
    parsed = _parse_geom_text(text)
    if parsed is None:
        return '', 0.0, []
    prediction, tokens, polyline = parsed

    token_list = [token.split('(')[0].replace(' ', '') for token in tokens]
    ratios = [float(token.split('(')[1]) for token in tokens]
    # Computed in both modes; previously bar-chart mode never set `area`.
    area = _polygon_area(polyline)

    if not flag_all_xsections:
        _draw_token_bars(index, tag, text, token_list, ratios)
    elif polyline:
        x, y = zip(*_closed_ring(polyline))
        color = color_list[label_list.index(prediction)]
        plt_xsec.fill(x, y, color=color)
        if prediction.find('pave') < 0:
            # Vertex average (closing vertex included); used only to place the label.
            centroid_x = sum(x) / len(x)
            centroid_y = sum(y) / len(y)
            plt_xsec.text(centroid_x, centroid_y, f'{prediction}={area:.2f}', horizontalalignment='center', verticalalignment='center', fontsize=5, color='black')
    return prediction, area, token_list
# Stations to draw, as they appear in the log's 'Station:' field.
# When this list is non-empty, output_logs() skips lines of any other
# station and main() only draws figures (no summary rows are written).
output_stations = [
    '4+440.00000',
    '3+780.00000',
    '3+800.00000',
    '3+880.00000',
    '3+940.00000',
]
def _read_log_lines(tag, equal):
    # Return the newline-stripped log lines selected for `tag` (and `equal` unless 'none').
    tag_model = tag.split(' ')[0]
    selected = []

    def passes_equal(line):
        # equal == 'none' disables the filter; otherwise the marker must occur past column 0.
        return equal == 'none' or line.find(equal) > 0

    with open(input_log_file, 'r') as file:
        if flag_all_xsections:
            # Cross-section mode: every line mentioning the model family
            # (first word of the tag, e.g. 'MLP') is kept.
            for raw in file:
                if raw.find(tag_model) < 0:
                    continue
                line = raw.replace('\n', '')
                if passes_equal(line):
                    selected.append(line)
        else:
            # Bar-chart mode: one line per label, namely the first line that
            # carries the full tag, 'Label: <label>' and the equal marker.
            for label in label_list:
                file.seek(0)
                for raw in file:
                    if raw.find(tag) < 0 or raw.find('Label: ' + label) < 0:
                        continue
                    line = raw.replace('\n', '')
                    if passes_equal(line):
                        selected.append(line)
                        break
    return selected


def _station_sort_key(text):
    # Sort key: the text after 'Station: ' up to the next comma (or end of line).
    sta_index = text.find('Station:') + 9
    end_index = text.find(',', sta_index)
    return text[sta_index:end_index] if end_index != -1 else text[sta_index:]


def _save_xsection(tag, station, equal_check):
    # Write the current cross-section figure to ./graph and close it.
    plt_xsec.savefig(f'./graph/polygon_{now_str}_{tag}_{station}_{equal_check}.png', dpi=300)
    plt_xsec.close()


def output_logs(tag, equal='none'):
    """Draw every selected log line of `tag` and return per-line summaries.

    Lines are read from ``input_log_file`` and sorted by their station string.
    Each line is passed to ``output_graph_matrics``. A new ``plt_xsec`` figure
    is started whenever the station changes, and each station's figure is
    saved as ``./graph/polygon_<now>_<tag>_<station>_<T|F>.png``. The T/F
    suffix comes from the last line drawn on that figure. When
    ``output_stations`` is non-empty, lines of other stations are skipped.

    Args:
        tag: model tag, e.g. 'MLP [128, 64, 32]'.
        equal: substring a line must contain (e.g. 'Equal: True'), or 'none'.

    Returns:
        A list of dicts with keys 'index', 'station', 'label', 'area' and
        'tokens'; empty when no line matches.
    """
    global prev_station

    logs = []
    text_list = _read_log_lines(tag, equal)
    if not text_list:
        return logs
    text_list = sorted(text_list, key=_station_sort_key)

    # Start every call without an open station. A value left over from the
    # previous call used to save a spurious figure, or to draw onto a figure
    # that was never set up.
    prev_station = ''
    fig_equal_check = 'F'
    station = ''
    for index, text in enumerate(text_list):
        sta_index = text.find('Station:')
        equal_index = text.find('Equal: ')
        equal_check = 'T' if text.find('True') > 0 else 'F'
        if sta_index > 0 and equal_index > 0:
            station = text[sta_index + 9:equal_index - 2]
            print(station)
            if output_stations and station not in output_stations:
                continue
            if prev_station != station:
                if prev_station:
                    _save_xsection(tag, prev_station, fig_equal_check)
                plt_xsec.figure()
                plt_xsec.gca().set_xlim([-60, 60])
                plt_xsec.gca().axis('equal')
                plt_xsec.gca().text(0, 0, f'{station}', fontsize=12, color='black')
                prev_station = station

        label, area, tokens = output_graph_matrics(index, tag, text)
        fig_equal_check = equal_check
        logs.append({
            'index': index,
            'station': station,
            'label': label,
            'area': area,
            'tokens': tokens,
        })

    # Save the last figure after the loop, so it is written even when the
    # final sorted lines belong to a filtered-out station.
    if logs:
        _save_xsection(tag, prev_station, fig_equal_check)
    return logs
def main():
    """Draw the colour legend, then process every model tag in the log.

    When ``output_stations`` is non-empty, only the figures are produced.
    Otherwise each tag is processed for 'Equal: True' and 'Equal: False'
    lines, and a row of summed areas and line counts is appended to
    ``./graph/summary_log.csv``.
    """
    draw_colorbox_list()

    summary_path = './graph/summary_log.csv'
    # The file is appended to, so the header is written only when the file is
    # new or empty; otherwise every run would add another header row.
    write_header = not os.path.exists(summary_path) or os.path.getsize(summary_path) == 0
    tags = ['MLP [128, 64, 32]', 'MLP [64, 128, 64]', 'MLP [64, 128, 64, 32]', 'LSTM [128]', 'LSTM [128, 64, 32]', 'LSTM [256, 128, 64]', 'transformer 32', 'transformer 64', 'transformer 128', 'BERT']

    with open(summary_path, 'a') as summary_log_file:
        if write_header:
            summary_log_file.write('model, ground true, length, ground false, length\n')
        for tag in tags:
            print(tag)
            if len(output_stations) > 0:
                output_logs(tag)
                continue

            logs1 = output_logs(tag, 'Equal: True')
            logs2 = output_logs(tag, 'Equal: False')
            if not logs1 or not logs2:
                continue

            area1 = sum(log['area'] for log in logs1)
            area2 = sum(log['area'] for log in logs2)
            # Tags such as 'MLP [128, 64, 32]' contain commas; quote the field
            # so it stays a single CSV column.
            summary_log_file.write(f'"{tag}", {area1}, {len(logs1)}, {area2}, {len(logs2)}\n')
            if flag_all_xsections:
                break
# Run the extraction only when executed as a script, not when imported.
if __name__ == '__main__':
    main()