import torch
import torch.nn as nn
from transformers import AutoConfig, T5EncoderModel

from .nn import SiLU, linear, timestep_embedding


class TransformerNetModel(nn.Module):
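    """Transformer backbone for a text-conditioned diffusion model over
    token embeddings.

    It maps a noisy embedding sequence ``x`` of shape
    (batch, seq_len, in_channels), a diffusion timestep, and a caption
    encoding to an output sequence in the same ``in_channels`` space.
    Conditioning is done by cross-attention: the T5 stack is built with
    ``is_decoder``/``add_cross_attention`` enabled and attends to the
    projected caption states.

    Usage sketch (illustrative shapes, not from the original file):

        model = TransformerNetModel(vocab_size=821)
        x = torch.randn(2, 16, 32)                  # noisy latents
        t = torch.randint(0, 1000, (2,))            # diffusion timesteps
        cap = torch.randn(2, 24, 768)               # caption encoder states
        mask = torch.ones(2, 24, dtype=torch.long)  # caption attention mask
        out = model(x, t, cap, mask)                # -> (2, 16, 32)
    """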
    def __init__(
        self,
        in_channels=32,
        model_channels=128,
        dropout=0.1,
        config_name="QizhiPei/biot5-base-text2mol",
        vocab_size=None,  # 821
        hidden_size=768,
        num_attention_heads=12,
        num_hidden_layers=12,
    ):
        super().__init__()

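        # Load the pretrained T5 config and enable decoder-style blocks with
        # cross-attention so the stack can attend to the caption encoding.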
        config = AutoConfig.from_pretrained(config_name)
        config.is_decoder = True
        config.add_cross_attention = True
        config.hidden_dropout_prob = 0.1
        config.num_attention_heads = num_attention_heads
        config.num_hidden_layers = num_hidden_layers
        config.max_position_embeddings = 512
        config.layer_norm_eps = 1e-12
        config.vocab_size = vocab_size
        config.d_model = hidden_size

        self.hidden_size = hidden_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.dropout = dropout
        self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
        self.lm_head = nn.Linear(self.in_channels, vocab_size)
        self.lm_head.weight = self.word_embedding.weight

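        # Project caption-encoder states (hidden size 768 assumed) into this
        # model's hidden size before they are used as cross-attention context.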
        self.caption_down_proj = nn.Sequential(
            linear(768, self.hidden_size),
            SiLU(),
            linear(self.hidden_size, self.hidden_size),
        )

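        # Embed diffusion timesteps into the transformer hidden size.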
        time_embed_dim = model_channels * 4  # 512
        self.time_embed = nn.Sequential(
            linear(self.model_channels, time_embed_dim),
            SiLU(),
            linear(time_embed_dim, self.hidden_size),
        )

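        # Lift the low-dimensional diffusion latents (in_channels) to the hidden size.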
        self.input_up_proj = nn.Sequential(
            nn.Linear(self.in_channels, self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )

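        # T5 encoder stack; with is_decoder/add_cross_attention set above, its
        # blocks also carry cross-attention over the caption states.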
        self.input_transformers = T5EncoderModel(config)
        # self.input_transformers.eval()
        # for param in self.input_transformers.parameters():
        #     param.requires_grad = False

        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, self.hidden_size
        )

        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
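        # Note: this replaces the float stored in self.dropout above with a Dropout module.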
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.output_down_proj = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size, self.in_channels),
        )

    def get_embeds(self, input_ids):
        return self.word_embedding(input_ids)

    def get_embeds_with_deep(self, input_ids):
        # Expects (atom_ids, depth_ids); relies on self.deep_embedding, which
        # is not created in __init__ and must be attached to the module first.
        atom, deep = input_ids
        atom = self.word_embedding(atom)
        deep = self.deep_embedding(deep)
        return torch.cat([atom, deep], dim=-1)

    def get_logits(self, hidden_repr):
        return self.lm_head(hidden_repr)

    def forward(self, x, timesteps, caption_state, caption_mask, y=None):
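        # Embed the diffusion timesteps; broadcast over the sequence below.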
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
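        # Lift the noisy latents to the hidden size, then add learned
        # positional embeddings and the broadcast timestep embedding.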
        emb_x = self.input_up_proj(x)
        seq_length = x.size(1)
        position_ids = self.position_ids[:, :seq_length]
        emb_inputs = (
            self.position_embeddings(position_ids)
            + emb_x
            + emb.unsqueeze(1).expand(-1, seq_length, -1)
        )
        emb_inputs = self.dropout(self.LayerNorm(emb_inputs))

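        # Project the caption encoding to the hidden size; it becomes the
        # cross-attention context for the transformer stack.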
        caption_state = self.dropout(
            self.LayerNorm(self.caption_down_proj(caption_state))
        )

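        # Run the T5 encoder stack over the embedded inputs, attending to the
        # caption states through cross-attention.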
        input_trans_hidden_states = self.input_transformers.encoder(
            inputs_embeds=emb_inputs,
            encoder_hidden_states=caption_state,
            encoder_attention_mask=caption_mask,
        ).last_hidden_state
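        # Project back down to the word-embedding dimension (in_channels).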
        h = self.output_down_proj(input_trans_hidden_states)
        h = h.type(x.dtype)
        return h