File size: 2,767 Bytes
0c84ee8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# -*- coding: utf-8 -*-
"""
Classes for controlling machine learning processes
"""
import numpy as np
import math
import matplotlib.pyplot as plt
import csv


class TrainingPlot:
    """
    Creating live plot during training.

    Train loss is drawn in blue on the left y-axis (``ax1``); validation
    accuracy (red) and training accuracy (green) are drawn on a twin right
    y-axis (``ax2``) that shares the iteration x-axis.

    REQUIRES notebook backend: %matplotlib notebook
    @TODO Migrate to Tensorboard
    """
    # Immutable defaults only; the history lists are created per instance in
    # __init__ so separate plots never share (and mutate) the same list.
    test_iter = 0
    loss_iter = 0
    interval = 0
    ax1 = None
    ax2 = None
    fig = None

    def __init__(self, steps, test_itr, loss_itr):
        """
        Create the figure, its two y-axes and empty history buffers.

        :param steps: stored as ``self.interval``; not used elsewhere in
            this class.
        :param test_itr: iterations between two accuracy samples; used as
            the x-spacing of the accuracy curves.
        :param loss_itr: iterations between two loss samples; used as the
            x-spacing of the loss curve.
        """
        self.train_loss = []
        self.train_acc = []
        self.valid_acc = []
        # Line2D handles keyed by series name. Created on the first update and
        # reused afterwards, so each update moves one line instead of stacking
        # a new (ever longer) line on top of the previous ones.
        self._lines = {}
        self.test_iter = test_itr
        self.loss_iter = loss_itr
        self.interval = steps

        self.fig, self.ax1 = plt.subplots()
        self.ax2 = self.ax1.twinx()
        self.ax1.set_autoscaley_on(True)
        plt.ion()

        # Description
        self.ax1.set_xlabel('Iteration')
        self.ax1.set_ylabel('Train Loss')
        self.ax2.set_ylabel('Valid. Accuracy')

        # Axes limits
        self.ax1.set_ylim([0, 10])

        # Draw once the labels are in place so the initial figure shows them.
        self._update_plot()

    def _update_plot(self):
        """Redraw the figure canvas so the latest data becomes visible."""
        self.fig.canvas.draw()

    def _draw_series(self, ax, key, x, y, fmt):
        """Plot ``y`` against ``x`` on ``ax``, reusing the line stored under ``key``."""
        line = self._lines.get(key)
        if line is None:
            (self._lines[key],) = ax.plot(x, y, fmt, linewidth=1.0)
        else:
            line.set_data(x, y)
            # set_data does not update data limits; recompute them so axes
            # that are still autoscaled follow the new points.
            ax.relim()
            ax.autoscale_view()

    def update_loss(self, loss_train, index):
        """
        Append a training-loss sample and redraw the loss curve.

        :param loss_train: latest training loss value.
        :param index: unused; kept for API compatibility.
        """
        self.train_loss.append(loss_train)
        if len(self.train_loss) == 1:
            # Fit the loss axis to the first sample, capped at 10. math.ceil
            # raises on NaN/inf, and a zero/negative loss would give a
            # degenerate [0, 0] range, so clamp both cases.
            if math.isfinite(loss_train):
                top = math.ceil(loss_train)
            else:
                top = 10
            self.ax1.set_ylim([0, min(10, max(1, top))])
        self._draw_series(self.ax1, 'train_loss',
                          self.loss_iter * np.arange(len(self.train_loss)),
                          self.train_loss, 'b')

        self._update_plot()

    def update_acc(self, acc_val, acc_train, index):
        """
        Append validation/training accuracy samples and redraw both curves.

        :param acc_val: latest validation accuracy (red curve, shown in title).
        :param acc_train: latest training accuracy (green curve).
        :param index: unused; kept for API compatibility.
        """
        self.valid_acc.append(acc_val)
        self.train_acc.append(acc_train)

        self._draw_series(self.ax2, 'valid_acc',
                          self.test_iter * np.arange(len(self.valid_acc)),
                          self.valid_acc, 'r')
        self._draw_series(self.ax2, 'train_acc',
                          self.test_iter * np.arange(len(self.train_acc)),
                          self.train_acc, 'g')

        self.ax2.set_title('Valid. Accuracy: {:.4f}'.format(self.valid_acc[-1]))

        self._update_plot()


class DataSet:
    """Class for training data and feeding train function.

    Serves consecutive mini-batches of ``(images, labels)``. The first epoch
    is served in the original order. When a batch would run past the end of
    the data, images and labels are reshuffled with one shared permutation
    and a new epoch starts from the beginning; samples left over at the end
    of the previous epoch are skipped for that epoch.
    """
    images = None
    labels = None
    length = 0
    index = 0

    def __init__(self, img, lbl):
        """
        :param img: sample inputs; must support ``len``, slicing and
            indexing with an integer index array (e.g. numpy arrays).
        :param lbl: labels aligned element-wise with ``img``.
        :raises ValueError: if ``img`` and ``lbl`` differ in length.
        """
        if len(img) != len(lbl):
            raise ValueError(
                'img and lbl must have the same length, got {} and {}'
                .format(len(img), len(lbl)))
        self.images = img
        self.labels = lbl
        self.length = len(img)
        self.index = 0

    def next_batch(self, batch_size):
        """Return the next batch from the data set.

        :param batch_size: number of samples per batch; must be positive.
            If it exceeds the data set size, the whole (freshly reshuffled)
            data set is returned on every call.
        :returns: tuple ``(images, labels)`` holding this batch.
        :raises ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(
                'batch_size must be positive, got {}'.format(batch_size))

        start = self.index
        self.index += batch_size

        if self.index > self.length:
            # Epoch finished: shuffle images and labels with the same
            # permutation so pairs stay aligned, then restart from 0.
            perm = np.random.permutation(self.length)
            self.images = self.images[perm]
            self.labels = self.labels[perm]
            start = 0
            self.index = batch_size

        end = self.index
        return self.images[start:end], self.labels[start:end]