File size: 3,272 Bytes
cfb7702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn

import models
from models.utils import get_activation
from models.network_utils import get_encoding, get_mlp
from systems.utils import update_module_step


@models.register('volume-radiance')
class VolumeRadiance(nn.Module):
    """MLP head that turns per-sample feature vectors into a 3-channel output.

    The module optionally conditions on an encoded view direction, but that
    path is currently switched off: ``with_viewdir`` is hard-coded to
    ``False``, so ``dirs`` is accepted by ``forward`` without being used.

    Config entries read by this class:
        n_dir_dims (optional, default 3): width of a direction vector.
        input_feature_dim: width of the feature vectors fed to the MLP.
        mlp_network_config: forwarded to ``get_mlp``.
        dir_encoding_config: forwarded to ``get_encoding`` (view-direction
            path only).
        color_activation (optional): name handed to ``get_activation`` and
            applied to the network output.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # View-direction conditioning is deliberately disabled here; the
        # conditional code below is kept so it can be re-enabled.
        self.with_viewdir = False
        self.n_dir_dims = self.config.get('n_dir_dims', 3)
        self.n_output_dims = 3

        dir_encoding = None
        mlp_in_dims = self.config.input_feature_dim
        if self.with_viewdir:
            dir_encoding = get_encoding(self.n_dir_dims, self.config.dir_encoding_config)
            # The MLP sees the features followed by the encoded directions.
            mlp_in_dims = self.config.input_feature_dim + dir_encoding.n_output_dims
        self.n_input_dims = mlp_in_dims

        self.encoding = dir_encoding
        self.network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config)

    def forward(self, features, dirs, *args):
        """Evaluate the MLP on ``features`` (plus optional extra inputs).

        Args:
            features: tensor whose last axis holds the feature vector; all
                leading axes are flattened for the MLP and restored on output.
            dirs: view directions; only used when ``with_viewdir`` is true.
            *args: additional tensors, each flattened to 2-D over its last
                axis and concatenated after the features (and directions).

        Returns:
            Float tensor shaped ``(*features.shape[:-1], n_output_dims)``,
            passed through the configured ``color_activation`` if one is set.
        """
        dir_columns = []
        if self.with_viewdir:
            # Remap direction components from (-1, 1) to (0, 1) before encoding.
            unit_dirs = (dirs + 1.) / 2.
            dir_columns = [self.encoding(unit_dirs.view(-1, self.n_dir_dims))]

        feature_columns = features.view(-1, features.shape[-1])
        extra_columns = [extra.view(-1, extra.shape[-1]) for extra in args]
        network_inp = torch.cat([feature_columns, *dir_columns, *extra_columns], dim=-1)

        leading_shape = features.shape[:-1]
        raw = self.network(network_inp)
        color = raw.view(*leading_shape, self.n_output_dims).float()

        if 'color_activation' not in self.config:
            return color
        activation = get_activation(self.config.color_activation)
        return activation(color)

    def update_step(self, epoch, global_step):
        """Pass the current epoch / global step on to the direction encoding.

        NOTE(review): ``self.encoding`` is ``None`` while view directions are
        disabled; presumably ``update_module_step`` tolerates that -- confirm.
        """
        update_module_step(self.encoding, epoch, global_step)

    def regularizations(self, out):
        """Return the regularization terms of this module (none)."""
        return {}


@models.register('volume-color')
class VolumeColor(nn.Module):
    """MLP head that maps per-sample feature vectors to a 3-channel output.

    Config entries read by this class:
        input_feature_dim: width of the feature vectors fed to the MLP.
        mlp_network_config: forwarded to ``get_mlp``.
        color_activation (optional): name handed to ``get_activation`` and
            applied to the network output.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_output_dims = 3
        self.n_input_dims = self.config.input_feature_dim
        self.network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config)

    def forward(self, features, *args):
        """Evaluate the MLP on ``features``.

        Args:
            features: tensor whose last axis holds the feature vector; all
                leading axes are flattened for the MLP and restored on output.
            *args: accepted but ignored.

        Returns:
            Float tensor shaped ``(*features.shape[:-1], n_output_dims)``,
            passed through the configured ``color_activation`` if one is set.
        """
        leading_shape = features.shape[:-1]
        flat_features = features.view(-1, features.shape[-1])
        raw = self.network(flat_features)
        color = raw.view(*leading_shape, self.n_output_dims).float()

        if 'color_activation' not in self.config:
            return color
        activation = get_activation(self.config.color_activation)
        return activation(color)

    def regularizations(self, out):
        """Return the regularization terms of this module (none)."""
        return {}