AlienChen committed
Commit df3a13c · verified · 1 Parent(s): 490329d

Delete models/enhancer_models.py

Files changed (1)
models/enhancer_models.py +0 -215
models/enhancer_models.py DELETED
@@ -1,215 +0,0 @@
- from torch import nn
- import torch
- import numpy as np
- import torch.nn.functional as F
- import copy
- import pdb
-
- class GaussianFourierProjection(nn.Module):
-     """
-     Gaussian random features for encoding time steps.
-     """
-
-     def __init__(self, embed_dim, scale=30.):
-         super().__init__()
-         # Randomly sample weights during initialization. These weights are fixed
-         # during optimization and are not trainable.
-         self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
-
-     def forward(self, x):
-         x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
-         return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
-
- class Dense(nn.Module):
-     """
-     A fully connected layer that reshapes outputs to feature maps.
-     """
-
-     def __init__(self, input_dim, output_dim):
-         super().__init__()
-         self.dense = nn.Linear(input_dim, output_dim)
-
-     def forward(self, x):
-         return self.dense(x)[...]
-
- class Swish(nn.Module):
-     def __init__(self):
-         super().__init__()
-
-     def forward(self, x):
-         return torch.sigmoid(x) * x
-
- class CNNModel(nn.Module):
-     """A time-dependent score-based model built upon U-Net architecture."""
-
-     def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
-         """
-         Args:
-             embed_dim (int): Dimensionality of the token and time embeddings.
-         """
-         super().__init__()
-         self.alphabet_size = alphabet_size
-
-         self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim)
-
-         self.time_embed = nn.Sequential(
-             GaussianFourierProjection(embed_dim=embed_dim),
-             nn.Linear(embed_dim, embed_dim)
-         )
-
-         self.swish = Swish()
-
-         n = hidden_dim
-
-         self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)
-
-         self.blocks = nn.ModuleList([
-             nn.Conv1d(n, n, kernel_size=9, padding=4),
-             nn.Conv1d(n, n, kernel_size=9, padding=4),
-             nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
-             nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
-             nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
-             # nn.Conv1d(n, n, kernel_size=9, padding=4),
-             # nn.Conv1d(n, n, kernel_size=9, padding=4),
-             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
-             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
-             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
-             # nn.Conv1d(n, n, kernel_size=9, padding=4),
-             # nn.Conv1d(n, n, kernel_size=9, padding=4),
-             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
-             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
-             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
-             # nn.Conv1d(n, n, kernel_size=9, padding=4),
-             # nn.Conv1d(n, n, kernel_size=9, padding=4),
-             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
-             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
-             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256)
-         ])
-
-         self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)])
-         self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)])
-
-         self.final = nn.Sequential(
-             nn.Conv1d(n, n, kernel_size=1),
-             nn.GELU(),
-             nn.Conv1d(n, self.alphabet_size, kernel_size=1)
-         )
-
-     def forward(self, x, t):
-         """
-         Args:
-             x: Tensor of shape (B, L) containing DNA token indices.
-             t: Tensor of shape (B,) containing the time steps.
-         Returns:
-             out: Tensor of shape (B, L, 4) with output logits for each DNA base.
-         """
-         x = self.token_embedding(x)  # (B, L) -> (B, L, embed_dim)
-
-         time_embed = self.swish(self.time_embed(t))  # (B, embed_dim)
-
-         out = x.permute(0, 2, 1)  # (B, L, embed_dim) -> (B, embed_dim, L)
-         out = self.swish(self.linear(out))  # (B, n, L)
-
-         # Process through convolutional blocks, adding time conditioning via dense layers.
-         for block, dense, norm in zip(self.blocks, self.denses, self.norms):
-             # dense(embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting.
-             h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
-             # Residual connection if shapes match.
-             if h.shape == out.shape:
-                 out = h + out
-             else:
-                 out = h
-
-         out = self.final(out)  # (B, 4, L)
-         out = out.permute(0, 2, 1)  # (B, L, 4)
-
-         # Normalization
-         out = out - out.mean(dim=-1, keepdim=True)
-         return out
-
- class MLPModel(nn.Module):
-     def __init__(
-             self, input_dim: int = 128, time_dim: int = 1, hidden_dim=128, length=500):
-         super().__init__()
-         self.input_dim = input_dim
-         self.time_dim = time_dim
-         self.hidden_dim = hidden_dim
-
-         self.time_embedding = nn.Linear(1, time_dim)
-         self.token_embedding = torch.nn.Embedding(self.input_dim, hidden_dim)
-
-         self.swish = Swish()
-
-         self.main = nn.Sequential(
-             self.swish,
-             nn.Linear(hidden_dim * length + time_dim, hidden_dim),
-             self.swish,
-             nn.Linear(hidden_dim, hidden_dim),
-             self.swish,
-             nn.Linear(hidden_dim, hidden_dim),
-             self.swish,
-             nn.Linear(hidden_dim, self.input_dim * length),
-         )
-
-     def forward(self, x, t):
-         '''
-         x shape (B, L)
-         t shape (B,)
-         '''
-         t = self.time_embedding(t.unsqueeze(-1))
-         x = self.token_embedding(x)
-
-         B, N, d = x.shape
-         x = x.reshape(B, N * d)
-
-         h = torch.cat([x, t], dim=1)
-         h = self.main(h)
-
-         h = h.reshape(B, N, self.input_dim)
-
-         return h
-
- class DirichletCNNModel(nn.Module):
-     def __init__(self, args, alphabet_size):
-         super().__init__()
-         self.alphabet_size = alphabet_size
-         self.args = args
-         expanded_simplex_input = args.cls_expanded_simplex and (args.mode == 'dirichlet' or args.mode == 'riemannian')
-         inp_size = self.alphabet_size * (2 if expanded_simplex_input else 1)
-         self.linear = nn.Conv1d(inp_size, args.hidden_dim, kernel_size=9, padding=4)
-         self.time_embedder = nn.Sequential(GaussianFourierProjection(embed_dim=args.hidden_dim), nn.Linear(args.hidden_dim, args.hidden_dim))
-
-         self.num_layers = 5 * args.num_cnn_stacks
-         self.convs = [nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
-                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
-                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=4, padding=16),
-                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=16, padding=64),
-                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=64, padding=256)]
-         self.convs = nn.ModuleList([copy.deepcopy(layer) for layer in self.convs for i in range(args.num_cnn_stacks)])
-         self.time_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
-         self.norms = nn.ModuleList([nn.LayerNorm(args.hidden_dim) for _ in range(self.num_layers)])
-         self.final_conv = nn.Sequential(nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=1),
-                                         nn.ReLU(),
-                                         nn.Conv1d(args.hidden_dim, self.alphabet_size, kernel_size=1))
-         self.dropout = nn.Dropout(args.dropout)
-
-     def forward(self, seq, t):
-         time_emb = F.relu(self.time_embedder(t))
-         feat = seq.permute(0, 2, 1)
-         feat = F.relu(self.linear(feat))
-
-         for i in range(self.num_layers):
-             h = self.dropout(feat.clone())
-             if not self.args.clean_data:
-                 h = h + self.time_layers[i](time_emb)[:, :, None]
-             h = self.norms[i](h.permute(0, 2, 1))
-             h = F.relu(self.convs[i](h.permute(0, 2, 1)))
-             if h.shape == feat.shape:
-                 feat = h + feat
-             else:
-                 feat = h
-         feat = self.final_conv(feat)
-         feat = feat.permute(0, 2, 1)
-         return feat
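
For anyone restoring this file from history: apart from standard PyTorch/NumPy, the deleted modules are self-contained, so a quick shape check is enough to confirm the behavior documented in the docstrings. The snippet below is a minimal smoke-test sketch, not part of the deleted file; the restore import path, batch size, sequence length, and the argparse.Namespace fields passed to DirichletCNNModel (cls_expanded_simplex, mode, hidden_dim, num_cnn_stacks, dropout, clean_data) are assumptions inferred from the attribute accesses in the code above.

# Minimal smoke-test sketch (assumed usage, not part of the deleted file).
from argparse import Namespace

import torch

from models.enhancer_models import CNNModel, DirichletCNNModel  # hypothetical restore path

B, L = 2, 500  # assumed batch size and sequence length

# CNNModel: DNA token indices in, per-position logits out.
cnn = CNNModel(alphabet_size=4, embed_dim=256, hidden_dim=256)
x = torch.randint(0, 4, (B, L))   # (B, L) token indices
t = torch.rand(B)                 # (B,) time steps in [0, 1)
print(cnn(x, t).shape)            # torch.Size([2, 500, 4])

# DirichletCNNModel reads its hyperparameters off an args object; these
# field values are assumptions based on the attributes the class accesses.
args = Namespace(cls_expanded_simplex=False, mode='dirichlet', hidden_dim=128,
                 num_cnn_stacks=1, dropout=0.0, clean_data=False)
dcnn = DirichletCNNModel(args, alphabet_size=4)
seq = torch.rand(B, L, 4)         # (B, L, alphabet_size) simplex-valued input
print(dcnn(seq, t).shape)         # torch.Size([2, 500, 4])

Note that every dilated convolution in both models is padded so the sequence length L is preserved, which is why the shape-matching residual additions in the forward loops fire on every iteration.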