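"""CLIP-based image head estimator.

Encodes a conditioning image with a frozen open_clip model and predicts a
probability distribution (normal or beta) for each configured head.
"""
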
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import open_clip
import torch
import torch.nn as nn
from jaxtyping import Float
from torch import Tensor
from torchvision.transforms import Normalize

from sf3d.models.network import get_activation
from sf3d.models.utils import BaseModule


@dataclass
class HeadSpec:
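    """Specification for a single prediction head."""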
    name: str
    out_channels: int
    n_hidden_layers: int
    output_activation: Optional[str] = None
    output_bias: float = 0.0
    add_to_decoder_features: bool = False
    shape: Optional[List[int]] = None


class ClipBasedHeadEstimator(BaseModule):
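    """Predicts per-head output distributions from frozen CLIP image features."""
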
    @dataclass
    class Config(BaseModule.Config):
        model: str = "ViT-B-32"
        pretrain: str = "laion2b_s34b_b79k"

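        # ["normal", "beta"]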
        distribution: str = "beta"

        # ["mean", "mode", "sample", "sample_mean"]
        distribution_eval: str = "mode"

        activation: str = "relu"
        hidden_features: int = 512
        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self):
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.cfg.model, pretrained=self.cfg.pretrain
        )
        self.model.eval()

        # Do not add the weights in self.model to the optimizer
        for param in self.model.parameters():
            param.requires_grad = False

        assert len(self.cfg.heads) > 0, "At least one head must be configured"
        heads = {}
        for head in self.cfg.heads:
            head_layers = []

            for i in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                ]

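            # Collapse the hidden layers into a single shared trunk, then
            # append two parallel branches that each regress one scalar:
            # the two parameters of the output distribution.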
            head_layers = [nn.Sequential(*head_layers)]
            head_layers += [
                nn.Sequential(
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                    nn.Linear(self.cfg.hidden_features, 1),
                )
                for _ in range(2)
            ]
            heads[head.name] = nn.ModuleList(head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(f"Unsupported activation: {activation}")

    def forward(
        self,
        cond_image: Float[Tensor, "B 1 H W 3"],
        sample: bool = True,
    ) -> Dict[str, Any]:
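        """Encode ``cond_image`` with the frozen CLIP model and run each head.

        Returns a dict mapping head names to point estimates (with the
        underlying distribution under ``"<name>_dist"``) when ``sample``
        is True, or to the distribution objects themselves otherwise.
        """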
        # Run the model
        # Resize cond_image to 224
        cond_image = nn.functional.interpolate(
            cond_image.flatten(0, 1).permute(0, 3, 1, 2),
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        )
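        # Normalize with the CLIP training statistics before encoding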
        cond_image = Normalize(
            mean=open_clip.constants.OPENAI_DATASET_MEAN,
            std=open_clip.constants.OPENAI_DATASET_STD,
        )(cond_image)
        image_features = self.model.encode_image(cond_image)

        # Run the heads
        outputs = {}

        for head_dict in self.cfg.heads:
            head_name = head_dict.name
            shared_head, d1_h, d2_h = self.heads[head_name]
            shared_features = shared_head(image_features)
            d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
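            # d1 and d2 parameterize the head's distribution:
            # (mean, pre-softplus scale) for "normal", (alpha, beta)
            # concentrations (after softplus) for "beta".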
            if self.cfg.distribution == "normal":
                mean = d1
                var = d2
                if mean.shape[-1] == 1:
                    outputs[head_name] = torch.distributions.Normal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var),
                    )
                else:
                    outputs[head_name] = torch.distributions.MultivariateNormal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var).diag_embed(),
                    )
            elif self.cfg.distribution == "beta":
                outputs[head_name] = torch.distributions.Beta(
                    torch.nn.functional.softplus(d1 + head_dict.output_bias),
                    torch.nn.functional.softplus(d2 + head_dict.output_bias),
                )
            else:
                raise NotImplementedError(
                    f"Unsupported distribution: {self.cfg.distribution}"
                )

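        # Collapse each distribution to a point estimate according to
        # cfg.distribution_eval; the distribution itself is kept under
        # the "<name>_dist" key.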
        if sample:
            for head_dict in self.cfg.heads:
                head_name = head_dict.name
                dist = outputs[head_name]

                if self.cfg.distribution_eval == "mean":
                    out = dist.mean
                elif self.cfg.distribution_eval == "mode":
                    out = dist.mode
                elif self.cfg.distribution_eval == "sample_mean":
                    # Average over the sample dimension (dim 0)
                    out = dist.sample([10]).mean(0)
                else:
                    # use rsample if gradient is needed
                    out = dist.rsample() if self.training else dist.sample()

                outputs[head_name] = get_activation(head_dict.output_activation)(out)
                outputs[f"{head_name}_dist"] = dist

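        # Optionally reshape sampled outputs and re-key heads that feed
        # the decoder.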
        for head in self.cfg.heads:
            if head.shape:
                if not sample:
                    raise ValueError(
                        "Cannot reshape non-sampled probabilistic outputs"
                    )
                outputs[head.name] = outputs[head.name].reshape(*head.shape)

            if head.add_to_decoder_features:
                outputs[f"decoder_{head.name}"] = outputs[head.name]
                del outputs[head.name]

        return outputs