# Spaces: Running
import csv | |
import matplotlib.pyplot as plt | |
import numpy as np | |
def _decode_train_csv(csv_path):
    """Read per-step training metrics from a CSV file into NumPy arrays.

    The file must start with a header row that contains the columns
    ``step``, ``train_loss``, ``val_loss`` and ``dice_score``; any other
    columns are ignored.

    Blank or missing cells in the three metric columns are read as NaN
    instead of aborting the whole load, so a log where not every metric is
    written on every row can still be plotted (matplotlib leaves a gap at
    NaN points). The ``step`` column must always be filled in.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A 4-tuple ``(steps, train_loss, val_loss, dice)`` with one entry per
        data row. ``steps`` has dtype ``np.uint``; the other three arrays
        have dtype ``np.float32``.

    Raises:
        KeyError: If a required column is absent from the header.
        ValueError: If a non-blank cell cannot be parsed as a number.
    """
    metric_columns = ('train_loss', 'val_loss', 'dice_score')
    steps = []
    metrics = {column: [] for column in metric_columns}

    # newline='' is the mode the csv module documents for reader input; it
    # lets the reader handle line endings inside quoted fields itself.
    with open(csv_path, newline='') as csv_file:
        for row in csv.DictReader(csv_file):
            steps.append(row['step'])
            for column in metric_columns:
                cell = row[column]
                # DictReader yields None for cells missing from a short row.
                is_blank = cell is None or not cell.strip()
                metrics[column].append('nan' if is_blank else cell)

    # The cells stay strings up to here; NumPy parses them while building
    # the typed arrays.
    return (np.array(steps, dtype=np.uint),
            np.array(metrics['train_loss'], dtype=np.float32),
            np.array(metrics['val_loss'], dtype=np.float32),
            np.array(metrics['dice_score'], dtype=np.float32))
def plot_train_data(csv_path, store=None, show=True, steps_in_epoch=-1):
    """Plot training loss, validation loss and Dice score against train step.

    The curves are drawn on the current pyplot figure; no new figure is
    created and the figure is not closed afterwards.

    Args:
        csv_path: Path of the metrics CSV, read with ``_decode_train_csv``.
        store: If truthy, passed to ``plt.savefig`` as the output target.
        show: If truthy, ``plt.show()`` is called after plotting.
        steps_in_epoch: When greater than 0, a short tick is drawn at the
            bottom of the plot at every step that is a multiple of this
            value, marking epoch boundaries. The default of -1 disables
            the ticks.
    """
    steps, train_loss, val_loss, dice = _decode_train_csv(csv_path)

    plt.plot(steps, train_loss, label='Training Loss')
    plt.plot(steps, val_loss, label='Validation loss')
    plt.plot(steps, dice, label='Dice Score')

    # A CSV with a header but no data rows has no last step to read, so the
    # epoch ticks are skipped instead of failing on steps[-1].
    if steps_in_epoch > 0 and len(steps) > 0:
        epoch_starts = [x for x in range(0, steps[-1]) if x % steps_in_epoch == 0]
        # The lines span -0.2..-0.05 while the y-axis starts at -0.1 (set
        # below), so only a short stub above the axis is visible.
        plt.vlines(epoch_starts, ymin=-0.2, ymax=-0.05)

    plt.ylim(-0.1, 1.1)
    plt.ylabel('Training loss')
    plt.xlabel('Train Step')
    plt.legend(loc="upper left")

    if store:
        plt.savefig(store)
    if show:
        plt.show()
def plot_multiple_val_losses(names, csvs, max_step=7000):
    """Overlay the validation-loss curves of several runs and show the plot.

    ``names`` and ``csvs`` are paired positionally with ``zip``; if their
    lengths differ, the surplus entries of the longer one are ignored.

    Args:
        names: Legend labels, one per run.
        csvs: Paths of the metrics CSV files, one per run, each read with
            ``_decode_train_csv``.
        max_step: Upper limit of the x-axis (train steps); the axis runs
            from 0 to this value. Pass ``None`` to leave the x-axis on
            matplotlib's automatic scaling. Defaults to 7000.
    """
    # The loop variable is deliberately not called ``csv`` so that it does
    # not shadow the imported csv module inside this function.
    for name, csv_path in zip(names, csvs):
        steps, _train_loss, val_loss, _dice = _decode_train_csv(csv_path)
        plt.plot(steps, val_loss, label=name)

    plt.ylim(-0.1, 1.1)
    if max_step is not None:
        plt.xlim(0, max_step)
    plt.ylabel('Validation Loss')
    plt.xlabel('Train Step')
    plt.legend(loc="upper left")
    plt.show()
if __name__ == "__main__":
    # Single-run alternative, kept for reference:
    #path = "D:\\Repos\\LungTumorSegmentation\\models\\metrics.csv"
    #plot_train_data(path)

    # One (legend label, metrics CSV path) pair per model configuration.
    runs = [
        ('Base 16: Multiplier: 2x', "C:\\Users\\vemun\\Desktop\\Plots\\16_2.csv"),
        ('Base 64: Multiplier: 2x', "C:\\Users\\vemun\\Desktop\\Plots\\64_2.csv"),
        ('Base 128: Multiplier: 2x', "C:\\Users\\vemun\\Desktop\\Plots\\128_2.csv"),
        ('Base 64: Multiplier: 3.5x', "C:\\Users\\vemun\\Desktop\\Plots\\64_3_5.csv"),
        ('Base 192: Multiplier: 1.5x', "C:\\Users\\vemun\\Desktop\\Plots\\192_1_5.csv"),
    ]
    names = [label for label, _ in runs]
    csvs = [csv_file for _, csv_file in runs]
    plot_multiple_val_losses(names, csvs)