File size: 822 Bytes
bd282c4
 
948bfd2
bd282c4
 
 
 
 
948bfd2
bd282c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from manipulate_model.encoder.encoder import Encoder
from manipulate_model.decoder.decoder import Decoder


class Model(nn.Module, PyTorchModelHubMixin):
    """Encoder followed by a decoder, both built from a shared config.

    The class also mixes in huggingface_hub's ``PyTorchModelHubMixin`` for
    Hub save/load support.
    """

    def __init__(self, config):
        """Build the encoder, then size and build the decoder.

        Args:
            config: Attribute-accessible config object. Its nested
                ``config.model.decoder`` node must accept attribute
                assignment.

        Note:
            This mutates ``config`` in place. After the encoder is built,
            ``config.model.decoder.temporal_dim`` and
            ``config.model.decoder.encoding_dim`` are overwritten with the
            encoder-reported sizes, so the decoder is constructed to match
            the encoder's output. The same config object is kept on
            ``self.config``.
        """
        super().__init__()
        self.config = config

        self.encoder = Encoder(self.config)

        # Copy the encoder's output geometry into the decoder section of the
        # config before building the Decoder from that same config.
        temporal_dim = self.encoder.get_temporal_dim()
        decoder_cfg = self.config.model.decoder
        decoder_cfg.temporal_dim = temporal_dim
        decoder_cfg.encoding_dim = self.encoder.get_encoding_dim()

        self.decoder = Decoder(self.config)

    def _encode(self, x):
        """Run the encoder, disabling autograd when the encoder is frozen."""
        if not self.config.model.encoder_freeze:
            return self.encoder(x)
        # Frozen encoder: no autograd graph is recorded, so no gradients reach
        # encoder parameters. The encoder's train/eval mode is not changed here.
        with torch.no_grad():
            return self.encoder(x)

    def forward(self, x):
        """Encode ``x`` and return the decoder's output for those features.

        If ``config.model.encoder_freeze`` is truthy, the encoder pass runs
        under ``torch.no_grad()``. The decoder pass always runs with the
        caller's current grad mode.
        """
        features = self._encode(x)
        return self.decoder(features)