from copy import deepcopy
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from clip.model import VisionTransformer


class CSD(nn.Module):
    """CSD model: a CLIP vision backbone with separate style and content heads.

    The defaults correspond to CLIP's ViT-L/14 configuration.
    """

    def __init__(
        self,
        vit_input_resolution: int = 224,
        vit_patch_size: int = 14,
        vit_width: int = 1024,
        vit_layers: int = 24,
        vit_heads: int = 16,
        vit_output_dim: int = 768,
    ) -> None:
        super().__init__()

        # CLIP vision backbone; its final projection is detached below so the
        # forward pass returns pre-projection features.
        self.backbone = VisionTransformer(
            input_resolution=vit_input_resolution,
            patch_size=vit_patch_size,
            width=vit_width,
            layers=vit_layers,
            heads=vit_heads,
            output_dim=vit_output_dim,
        )

        # Clone the backbone's projection into independent style and content
        # heads, then drop it from the backbone itself.
        self.last_layer_style = deepcopy(self.backbone.proj)
        self.last_layer_content = deepcopy(self.backbone.proj)
        self.backbone.proj = None

    def forward(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Pre-projection CLS features of shape (batch, vit_width).
        features = self.backbone(pixel_values)

        # Project into the style embedding space and L2-normalize.
        style_output = features @ self.last_layer_style
        style_output = F.normalize(style_output, dim=1, p=2)

        # Project into the content embedding space and L2-normalize.
        content_output = features @ self.last_layer_content
        content_output = F.normalize(content_output, dim=1, p=2)

        return features, style_output, content_output
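

if __name__ == "__main__":
    # Minimal usage sketch with randomly initialized weights. In practice the
    # backbone would be loaded from a pretrained CLIP / CSD checkpoint; the
    # loading step is omitted here as it depends on the checkpoint format.
    model = CSD()
    model.eval()

    # A batch of two 224x224 RGB images.
    pixel_values = torch.randn(2, 3, 224, 224)

    with torch.no_grad():
        features, style_output, content_output = model(pixel_values)

    print(features.shape)        # torch.Size([2, 1024]) pre-projection features
    print(style_output.shape)    # torch.Size([2, 768])  normalized style embedding
    print(content_output.shape)  # torch.Size([2, 768])  normalized content embedding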