Yanisadel committed
Commit 95d75ca (verified)
1 Parent(s): 228437f

Upload SegmentBorzoi

Files changed (2)
  1. config.json +7 -0
  2. segment_borzoi.py +272 -0
config.json CHANGED
@@ -2,6 +2,11 @@
   "architectures": [
     "SegmentBorzoi"
   ],
+  "attention_dim_key": 64,
+  "auto_map": {
+    "AutoConfig": "segment_borzoi.SegmentBorzoiConfig",
+    "AutoModel": "segment_borzoi.SegmentBorzoi"
+  },
   "dim_divisible_by": 32,
   "embed_dim": 1536,
   "features": [
@@ -21,6 +26,8 @@
     "promoter_Tissue_invariant"
   ],
   "model_type": "segment_borzoi",
+  "num_attention_heads": 8,
+  "num_rel_pos_features": 32,
   "torch_dtype": "float32",
   "transformers_version": "4.41.1"
 }
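The added auto_map entries are what let transformers resolve the custom classes defined in segment_borzoi.py when the repository is loaded from the Hub. Below is a minimal loading sketch; the repository id is a placeholder (the target repo is not named on this page), and trust_remote_code=True is needed so that transformers is allowed to import the remote module referenced by auto_map.

from transformers import AutoConfig, AutoModel

# Placeholder repository id; substitute the Hub repo this commit was pushed to.
repo_id = "<user>/<segment-borzoi-repo>"

# trust_remote_code=True lets transformers import SegmentBorzoiConfig and
# SegmentBorzoi from segment_borzoi.py via the auto_map defined in config.json.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)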
segment_borzoi.py ADDED
@@ -0,0 +1,272 @@
+from typing import Any, Dict, List
+
+import borzoi_pytorch
+import torch
+import torch.nn as nn
+from einops import rearrange
+from torch import einsum
+from transformers import PretrainedConfig, PreTrainedModel
+
+from genomics_research.segmentnt.layers.torch.segmentation_head import TorchUNetHead
+
+FEATURES = [
+    "protein_coding_gene",
+    "lncRNA",
+    "exon",
+    "intron",
+    "splice_donor",
+    "splice_acceptor",
+    "5UTR",
+    "3UTR",
+    "CTCF-bound",
+    "polyA_signal",
+    "enhancer_Tissue_specific",
+    "enhancer_Tissue_invariant",
+    "promoter_Tissue_specific",
+    "promoter_Tissue_invariant",
+]
+
+
+class SegmentBorzoiConfig(PretrainedConfig):
+    model_type = "segment_borzoi"
+
+    def __init__(
+        self,
+        features: List[str] = FEATURES,
+        embed_dim: int = 1536,
+        dim_divisible_by: int = 32,
+        attention_dim_key: int = 64,
+        num_attention_heads: int = 8,
+        num_rel_pos_features: int = 32,
+        **kwargs: Dict[str, Any],
+    ):
+        self.features = features
+        self.embed_dim = embed_dim
+        self.dim_divisible_by = dim_divisible_by
+        self.attention_dim_key = attention_dim_key
+        self.num_attention_heads = num_attention_heads
+        self.num_rel_pos_features = num_rel_pos_features
+
+        super().__init__(**kwargs)
+
+
+class SegmentBorzoi(PreTrainedModel):
+    config_class = SegmentBorzoiConfig
+
+    def __init__(self, config: SegmentBorzoiConfig):
+        super().__init__(config=config)
+        borzoi = borzoi_pytorch.Borzoi.from_pretrained("johahi/borzoi-replicate-0")
+
+        # Stem
+        self.stem = borzoi.conv_dna
+
+        # Conv tower
+        self.res_tower = borzoi.res_tower
+        self.unet1 = borzoi.unet1
+        self._max_pool = borzoi._max_pool
+
+        # Transformer tower
+        self.transformer = borzoi.transformer
+
+        # UNet convolution layers
+        self.horizontal_conv1 = borzoi.horizontal_conv1
+        self.horizontal_conv0 = borzoi.horizontal_conv0
+        self.upsampling_unet1 = borzoi.upsampling_unet1
+        self.upsampling_unet0 = borzoi.upsampling_unet0
+        self.separable1 = borzoi.separable1
+        self.separable0 = borzoi.separable0
+
+        # Target length crop
+        self.crop = borzoi.crop
+
+        # Final convolution block
+        self.final_joined_convs = borzoi.final_joined_convs
+
+        self.unet_head = TorchUNetHead(
+            features=config.features,
+            embed_dimension=config.embed_dim,
+            nucl_per_token=config.dim_divisible_by,
+            remove_cls_token=False,
+        )
+
+        # Replace each transformer attention layer with the custom
+        # BorzoiAttentionLayer defined below
+        for layer in self.transformer:
+            layer[0].fn[1] = BorzoiAttentionLayer(
+                config.embed_dim,
+                heads=config.num_attention_heads,
+                dim_key=config.attention_dim_key,
+                dim_value=config.embed_dim // config.num_attention_heads,
+                dropout=0.05,
+                pos_dropout=0.01,
+                num_rel_pos_features=config.num_rel_pos_features,
+            )
+
+        # Correct bias in separable layers
+        self.separable1.conv_layer[1].bias = None
+        self.separable0.conv_layer[1].bias = None
+
+    def forward(self, x):
+        # Stem
+        x = x.transpose(1, 2)
+        x = self.stem(x)
+
+        # Conv tower
+        x_unet0 = self.res_tower(x)
+        x_unet1 = self.unet1(x_unet0)
+        x = self._max_pool(x_unet1)
+
+        # Transformer tower
+        x = x.permute(0, 2, 1)
+        x = self.transformer(x)
+        x = x.permute(0, 2, 1)
+
+        # UNet conv
+        x_unet1 = self.horizontal_conv1(x_unet1)
+        x_unet0 = self.horizontal_conv0(x_unet0)
+
+        # UNet upsampling and separable convolutions
+        x = self.upsampling_unet1(x)
+        x += x_unet1
+        x = self.separable1(x)
+        x = self.upsampling_unet0(x)
+        x += x_unet0
+        x = self.separable0(x)
+
+        # Target length crop
+        x = self.crop(x.permute(0, 2, 1))
+        x = x.permute(0, 2, 1)
+
+        # Final convolution block
+        x = self.final_joined_convs(x)
+
+        x = self.unet_head(x)
+
+        return x
+
+
+# Define a custom attention layer for the PyTorch model because the attention layer
+# of the imported model is not the same (the positional embeddings differ)
+def _prepend_dims(tensor: torch.Tensor, num_dims: int) -> torch.Tensor:
+    """Prepends dimensions to match the required shape."""
+    for _ in range(num_dims - tensor.dim()):
+        tensor = tensor.unsqueeze(0)
+    return tensor
+
+
+def get_positional_features_central_mask_borzoi(
+    positions: torch.Tensor, feature_size: int, seq_length: int
+) -> torch.Tensor:
+    """Positional features using a central mask (allow only central features)."""
+    pow_rate = torch.exp(torch.log(torch.tensor(seq_length + 1.0)) / feature_size)
+    center_widths = torch.pow(pow_rate, torch.arange(1, feature_size + 1).float()) - 1
+    center_widths = _prepend_dims(center_widths, positions.ndim)
+    outputs = (center_widths > torch.abs(positions).unsqueeze(-1)).float()
+    return outputs
+
+
+def get_positional_embed_borzoi(seq_len: int, feature_size: int) -> torch.Tensor:
+    """
+    Compute the positional embedding for Borzoi. Note that it is different from the
+    one used in Enformer.
+    """
+    distances = torch.arange(-seq_len + 1, seq_len)
+
+    num_components = 2
+
+    if (feature_size % num_components) != 0:
+        raise ValueError(
+            f"feature size is not divisible by number of components ({num_components})"
+        )
+
+    num_basis_per_class = feature_size // num_components
+
+    embeddings = []
+
+    embeddings.append(
+        get_positional_features_central_mask_borzoi(
+            distances, num_basis_per_class, seq_len
+        )
+    )
+
+    embeddings = torch.cat(embeddings, dim=-1)
+    embeddings = torch.cat(
+        (embeddings, torch.sign(distances).unsqueeze(-1) * embeddings), dim=-1
+    )
+    return embeddings
+
+
+def relative_shift(x: torch.Tensor) -> torch.Tensor:
+    to_pad = torch.zeros_like(x[..., :1])
+    x = torch.cat((to_pad, x), dim=-1)
+    _, h, t1, t2 = x.shape
+    x = x.reshape(-1, h, t2, t1)
+    x = x[:, :, 1:, :]
+    x = x.reshape(-1, h, t1, t2 - 1)
+    return x[..., : ((t2 + 1) // 2)]
+
+
+class BorzoiAttentionLayer(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        num_rel_pos_features,
+        heads=8,
+        dim_key=64,
+        dim_value=64,
+        dropout=0.0,
+        pos_dropout=0.0,
+    ):
+        super().__init__()
+        self.scale = dim_key**-0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(dim, dim_key * heads, bias=False)
+        self.to_k = nn.Linear(dim, dim_key * heads, bias=False)
+        self.to_v = nn.Linear(dim, dim_value * heads, bias=False)
+
+        self.to_out = nn.Linear(dim_value * heads, dim)
+        nn.init.zeros_(self.to_out.weight)
+        nn.init.zeros_(self.to_out.bias)
+
+        self.num_rel_pos_features = num_rel_pos_features
+
+        self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False)
+        self.rel_content_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
+        self.rel_pos_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
+
+        # dropouts
+        self.pos_dropout = nn.Dropout(pos_dropout)
+        self.attn_dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        n, h = x.shape[-2], self.heads
+
+        q = self.to_q(x)
+        k = self.to_k(x)
+        v = self.to_v(x)
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+        q = q * self.scale
+
+        content_logits = einsum(
+            "b h i d, b h j d -> b h i j", q + self.rel_content_bias, k
+        )
+
+        positions = get_positional_embed_borzoi(n, self.num_rel_pos_features)
+        positions = self.pos_dropout(positions)
+        rel_k = self.to_rel_k(positions)
+
+        rel_k = rearrange(rel_k, "n (h d) -> h n d", h=h)
+        rel_logits = einsum("b h i d, h j d -> b h i j", q + self.rel_pos_bias, rel_k)
+        rel_logits = relative_shift(rel_logits)
+
+        logits = content_logits + rel_logits
+        attn = logits.softmax(dim=-1)
+        attn = self.attn_dropout(attn)
+
+        out = einsum("b h i j, b h j d -> b h i d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        return self.to_out(out)
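
For reference, a minimal sketch of how the classes in this file could be exercised directly, assuming borzoi_pytorch, einops, transformers and the genomics_research package providing TorchUNetHead are installed. It only demonstrates the wiring and input shape: one-hot encoded DNA of shape (batch, sequence_length, 4), with 524,288 bp assumed here as the input length of the published Borzoi checkpoints. Constructing the model this way leaves the segmentation head randomly initialized; pretrained SegmentBorzoi weights would normally be fetched through AutoModel.from_pretrained as shown after the config.json diff above.

import torch
from segment_borzoi import SegmentBorzoi, SegmentBorzoiConfig

config = SegmentBorzoiConfig()  # defaults mirror the values committed in config.json
model = SegmentBorzoi(config)   # downloads the johahi/borzoi-replicate-0 backbone weights
model.eval()

# Hypothetical toy input: one batch of one-hot encoded DNA, shape (batch, sequence_length, 4);
# forward() transposes it to (batch, 4, sequence_length) before the convolutional stem.
sequence_length = 524_288
x = torch.nn.functional.one_hot(
    torch.randint(0, 4, (1, sequence_length)), num_classes=4
).float()

with torch.no_grad():
    out = model(x)
# `out` is whatever TorchUNetHead returns (per-position predictions for the 14
# genomic features in FEATURES); its exact structure depends on that head's implementation.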