|
|
|
|
|
from __future__ import absolute_import |
|
|
|
import os |
|
import sys |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
# Public names exported by ``from <module> import *``. ``plot_overlap`` is a
# helper used by LoggerMonitor.plot and is not listed here.
__all__ = ["Logger", "LoggerMonitor", "savefig"]
|
|
|
|
|
def savefig(fname, dpi=None, **kwargs):
    """Save the current matplotlib figure to ``fname``.

    Args:
        fname: Output path passed straight to ``matplotlib.pyplot.savefig``.
        dpi: Resolution in dots per inch; 150 when ``None``.
        **kwargs: Extra keyword arguments forwarded to
            ``matplotlib.pyplot.savefig`` (e.g. ``bbox_inches="tight"`` to keep
            a legend drawn outside the axes inside the saved image).
    """
    # Identity check per PEP 8; ``== None`` can misbehave for objects that
    # overload ``__eq__`` (numpy scalars/arrays, for instance).
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi, **kwargs)
|
|
|
|
|
def plot_overlap(logger, names=None):
    """Plot selected series of ``logger`` onto the current matplotlib axes.

    Each series is drawn against its index (0, 1, 2, ...).

    Args:
        logger: Object exposing ``names``, ``numbers`` (a mapping from name to
            a list of values) and ``title`` -- e.g. a ``Logger`` instance.
        names: Iterable of series names to plot; defaults to ``logger.names``.

    Returns:
        list of str: One legend label per plotted series, formatted as
        ``"<title>(<name>)"``, in the same order as the plotted lines.
    """
    # Materialise once so any iterable works and can be walked twice below.
    names = list(logger.names if names is None else names)
    numbers = logger.numbers
    for name in names:
        # Values loaded by Logger(resume=True) are the raw strings read from
        # the log file; convert explicitly so matplotlib plots them as numbers
        # instead of treating them as categorical labels.
        values = np.asarray(numbers[name], dtype=float)
        plt.plot(np.arange(len(values)), values)
    return [logger.title + "(" + name + ")" for name in names]
|
|
|
|
|
class Logger(object):
    """Save training process to log file with simple plot function.

    The log is a tab-separated text file: one header line listing the column
    names, then one line per ``append`` call with every value formatted as
    ``"{0:.6f}"``. Every field, including the last one on a line, is followed
    by a tab.

    When ``fpath`` is ``None`` nothing is written to disk; values are only
    recorded in memory in ``self.numbers``.
    """

    def __init__(self, fpath, title=None, resume=False):
        """Create a logger, optionally resuming an existing log file.

        Args:
            fpath: Path of the log file, or ``None`` for an in-memory logger.
            title: Prefix used in plot legend labels; ``""`` when ``None``.
            resume: If true, read the existing file's header into
                ``self.names`` and its rows into ``self.numbers`` (values are
                kept as the raw strings from the file), then reopen the file
                for appending. Otherwise the file is created/truncated.
        """
        self.file = None
        self.resume = resume
        self.title = "" if title is None else title
        # Always defined so an in-memory logger (fpath=None) is usable too.
        self.names = []
        self.numbers = {}
        if fpath is None:
            return
        if resume:
            # ``with`` guarantees the read handle is released even if parsing fails.
            with open(fpath, "r") as log_file:
                self.names = log_file.readline().rstrip().split("\t")
                self.numbers = {name: [] for name in self.names}
                for line in log_file:
                    values = line.rstrip().split("\t")
                    for i, value in enumerate(values):
                        self.numbers[self.names[i]].append(value)
            self.file = open(fpath, "a")
        else:
            self.file = open(fpath, "w")

    def set_names(self, names):
        """Declare the column names, reset the history and write the header.

        When resuming a log whose header already lists exactly ``names``, the
        header is not written again and the loaded history is kept: the file
        is open in append mode, so rewriting it would put a second header in
        the middle of the log and discard the values read on resume.
        """
        names = list(names)
        if self.resume and names == self.names:
            return
        self.names = names
        self.numbers = {name: [] for name in self.names}
        if self.file is not None:
            self.file.write("".join(name + "\t" for name in self.names) + "\n")
            self.file.flush()

    def append(self, numbers):
        """Record one row of values, one per column in ``self.names`` order.

        Raises:
            AssertionError: if the number of values differs from the number
                of column names.
            ValueError: if a value cannot be formatted with ``"{0:.6f}"``;
                in that case nothing is recorded or written.
        """
        assert len(self.names) == len(numbers), "Numbers do not match names"
        # Format the whole row first so an unformattable value leaves neither
        # a half-written line in the file nor a partial row in memory.
        line = "".join("{0:.6f}\t".format(num) for num in numbers) + "\n"
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)
        if self.file is not None:
            self.file.write(line)
            self.file.flush()

    def plot(self, names=None):
        """Plot the selected columns (default: all) with a legend and grid."""
        names = list(self.names if names is None else names)
        for name in names:
            # Resumed values are strings; convert so they plot numerically
            # rather than as categorical labels.
            values = np.asarray(self.numbers[name], dtype=float)
            plt.plot(np.arange(len(values)), values)
        plt.legend([self.title + "(" + name + ")" for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying log file, if any."""
        if self.file is not None:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the file; never suppress the exception.
        self.close()
        return False
|
|
|
|
|
class LoggerMonitor(object):
    """Load and visualize multiple logs."""

    def __init__(self, paths):
        """Load every log listed in ``paths``.

        Args:
            paths: A dictionary of ``{title: filepath}`` pairs. Each title is
                used as the legend prefix for that log's curves.
        """
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # Only the loaded history is needed for plotting. Logger(resume=True)
            # leaves the file open for appending, so release the handle now
            # instead of leaking one open file per monitored log.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        """Overlay the selected series of every loaded log in a new figure.

        Args:
            names: Series names to plot for each log; ``None`` plots every
                series of each log.
        """
        plt.figure()
        # Draw in the left cell of a 1x2 grid; presumably this leaves room on
        # the right for the legend anchored outside the axes below.
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
        plt.grid(True)
|
|
|
|
|
if __name__ == "__main__":
    # Demo: overlay the "Valid Acc." curve of several logs and save the figure.
    # Usage: python <this file> [TITLE=LOG_PATH ...]
    # Without arguments, the original author's hard-coded logs are used.
    paths = {}
    for arg in sys.argv[1:]:
        title, sep, path = arg.partition("=")
        if not sep or not title or not path:
            sys.exit("usage: %s [TITLE=LOG_PATH ...]" % sys.argv[0])
        paths[title] = path
    if not paths:
        paths = {
            "resadvnet20": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt",
            "resadvnet32": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt",
            "resadvnet44": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt",
        }

    field = ["Valid Acc."]

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig("test.eps")
|
|