import time

import torch
import torch.nn.functional as F

from .constants import *
from .lr_scheduling import get_lr
from utilities.device import get_device

def train_epoch(cur_epoch, model, dataloader, loss, opt, lr_scheduler=None, print_modulus=1):
    """Run one training epoch, stepping the optimizer (and the LR scheduler,
    if one is given) after every batch and printing progress every
    `print_modulus` batches."""
    model.train()
    for batch_num, batch in enumerate(dataloader):
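        # One optimization step per batch: clear gradients, run the forward
        # pass, compute the loss, backpropagate, and step the optimizer.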
        time_before = time.time()
        opt.zero_grad()

        # Move each semantic feature tensor in the batch to the compute device
        feature_semantic_list = [fs.to(get_device()) for fs in batch["semanticList"]]

        feature_scene_offset = batch["scene_offset"].to(get_device())
        feature_motion = batch["motion"].to(get_device())
        feature_emotion = batch["emotion"].to(get_device())

        feature_note_density = batch["note_density"].to(get_device())
        feature_loudness = batch["loudness"].to(get_device())

        # Predict a (note_density, loudness) pair for every timestep
        y = model(feature_semantic_list,
                  feature_scene_offset,
                  feature_motion,
                  feature_emotion)

        # Collapse (batch, seq, 2) predictions to (batch * seq, 2) for the loss
        y = y.reshape(y.shape[0] * y.shape[1], -1)

        # Build the matching (batch * seq, 2) target: note density in column 0,
        # loudness in column 1
        feature_loudness = feature_loudness.flatten().reshape(-1, 1)
        feature_note_density = feature_note_density.flatten().reshape(-1, 1)
        feature_combined = torch.cat((feature_note_density, feature_loudness), dim=1)

        # Calling the loss module directly (rather than .forward) is the idiomatic form
        out = loss(y, feature_combined)
        out.backward()
        opt.step()

        # The LR scheduler, when present, is stepped once per batch
        if lr_scheduler is not None:
            lr_scheduler.step()
        time_took = time.time() - time_before
        
        if (batch_num + 1) % print_modulus == 0:
            print(SEPERATOR)
            print("Epoch", cur_epoch, " Batch", batch_num + 1, "/", len(dataloader))
            print("LR:", get_lr(opt))
            print("Train loss:", float(out))
            print("")
            print("Time (s):", time_took)
            print(SEPERATOR)
            print("")
    return

def eval_model(model, dataloader, loss):
    """Evaluate the model on `dataloader`, returning the average loss plus the
    overall and per-target (note density, loudness) RMSE."""
    model.eval()

    with torch.no_grad():
        n_test = len(dataloader)

        sum_loss = 0.0
        sum_rmse = 0.0
        sum_rmse_note_density = 0.0
        sum_rmse_loudness = 0.0

        for batch in dataloader:
            # Move each semantic feature tensor in the batch to the compute device
            feature_semantic_list = [fs.to(get_device()) for fs in batch["semanticList"]]

            feature_scene_offset = batch["scene_offset"].to(get_device())
            feature_motion = batch["motion"].to(get_device())
            feature_emotion = batch["emotion"].to(get_device())
            feature_loudness = batch["loudness"].to(get_device())
            feature_note_density = batch["note_density"].to(get_device())
            
            y = model(feature_semantic_list,
                      feature_scene_offset,
                      feature_motion,
                      feature_emotion)

            # Collapse (batch, seq, 2) predictions to (batch * seq, 2), as in training
            y = y.reshape(y.shape[0] * y.shape[1], -1)

            # Targets: note density in column 0, loudness in column 1
            feature_loudness = feature_loudness.flatten().reshape(-1, 1)
            feature_note_density = feature_note_density.flatten().reshape(-1, 1)
            feature_combined = torch.cat((feature_note_density, feature_loudness), dim=1)

            # Overall RMSE across both target columns
            mse = F.mse_loss(y, feature_combined)
            rmse = torch.sqrt(mse)
            sum_rmse += float(rmse)

            # Split predictions back into single columns; the order matches the
            # torch.cat above (note density first, then loudness)
            y_note_density, y_loudness = torch.split(y, split_size_or_sections=1, dim=1)

            mse_note_density = F.mse_loss(y_note_density, feature_note_density)
            rmse_note_density = torch.sqrt(mse_note_density)
            sum_rmse_note_density += float(rmse_note_density)
            
            mse_loudness = F.mse_loss(y_loudness, feature_loudness)
            rmse_loudness = torch.sqrt(mse_loudness)
            sum_rmse_loudness += float(rmse_loudness)

            out = loss(y, feature_combined)
            sum_loss += float(out)
            
        avg_loss = sum_loss / n_test
        avg_rmse = sum_rmse / n_test
        avg_rmse_note_density = sum_rmse_note_density / n_test
        avg_rmse_loudness = sum_rmse_loudness / n_test

    return avg_loss, avg_rmse, avg_rmse_note_density, avg_rmse_loudness
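
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). Everything named below — VideoMusicModel,
# build_dataloaders, and the hyperparameter values — is a hypothetical
# stand-in, not part of this module; it only shows how train_epoch and
# eval_model are meant to be wired together.
#
#     model = VideoMusicModel().to(get_device())
#     train_loader, test_loader = build_dataloaders()
#     loss_fn = torch.nn.MSELoss()
#     opt = torch.optim.Adam(model.parameters(), lr=1e-4)
#
#     for epoch in range(10):
#         train_epoch(epoch + 1, model, train_loader, loss_fn, opt)
#         avg_loss, avg_rmse, rmse_nd, rmse_loud = eval_model(model, test_loader, loss_fn)
#         print(f"epoch {epoch + 1}: loss={avg_loss:.4f}, rmse={avg_rmse:.4f}")
# ---------------------------------------------------------------------------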