update
Files changed:
- README.md (+3 -3)
- model.safetensors (+1 -1)
- modeling_omnigenome.py (+251 -63)
README.md
CHANGED
@@ -6,7 +6,7 @@ language:
 
 tags:
 - Genomic-Language-Modeling
-
+- OmniGenome Foundation Model
 ---
 
 # Multi-species Foundation Model for Universal RNA and DNA Downstream Tasks
@@ -15,13 +15,13 @@ tags:
 We are keep updating the checkpoints, the current checkpoint is trained for 0.85 epoch.
 
 ## Training Examples
-Refer to GitHub [https://github.com/yangheng95/
+Refer to GitHub [https://github.com/yangheng95/OmniGenome](https://github.com/yangheng95/OmniGenome)
 
 ## Usage
 This model is available for replacing genomic foundation models such as CDSBERT, Nucleotide Transformers, DNABERT2, etc.
 ```
 from transformers import AutoModel
-model = AutoModel.from_pretrained("yangheng/
+model = AutoModel.from_pretrained("yangheng/OmniGenome-52M", trust_remote_code=True)
 ```
 
 ## Subtasks
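
The README's new snippet loads only the bare encoder. A slightly fuller sketch follows; the shared tokenizer repo id and the example RNA sequence are illustrative assumptions, not part of the README.

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Minimal sketch: embed one RNA sequence with the checkpoint named in the README.
# Assumes the tokenizer is published under the same repo id as the weights.
repo = "yangheng/OmniGenome-52M"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("GGGAAACGCUUCGGCGUUUCCC", return_tensors="pt")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state  # (1, sequence_length, hidden_size)
print(hidden.shape)
```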
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:2300e9ae1743dac0e51fe56eec44718d64b408c349e5f0bc98c567d339fe7938
 size 210828112
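
The model.safetensors change only swaps the LFS pointer's object id. Because a git-lfs oid is the SHA-256 of the raw file contents, a downloaded copy can be checked against the pointer above with a short script (the local path is an assumption):

```python
import hashlib

# Recompute the git-lfs object id (SHA-256 of the file contents) and compare it
# with the pointer above: oid 2300e9ae...fe7938, size 210828112 bytes.
def lfs_oid(path: str) -> str:
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

print(lfs_oid("model.safetensors"))
```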
modeling_omnigenome.py
CHANGED
@@ -15,8 +15,11 @@
 """ PyTorch OmniGenome model."""
 
 import math
+import random
+import warnings
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -1117,7 +1120,7 @@ class OmniGenomeForMaskedLM(OmniGenomePreTrainedModel):
 
         self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
         self.lm_head = OmniGenomeLMHead(config)
-
+        self.init_weights()
 
     def get_output_embeddings(self):
         return self.lm_head.decoder
@@ -1236,8 +1239,9 @@ class OmniGenomeForSequenceClassification(OmniGenomePreTrainedModel):
         self.num_labels = config.num_labels
         self.config = config
         self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
+        self.pooler = OmniGenomePooler(config)
         self.classifier = OmniGenomeClassificationHead(config)
-
+        self.init_weights()
 
     @add_start_docstrings_to_model_forward(
         OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1279,8 +1283,10 @@ class OmniGenomeForSequenceClassification(OmniGenomePreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-
-
+        last_hidden_state = outputs[0]
+        last_hidden_state = self.dense(last_hidden_state)
+        pooled_output = self.pooler(last_hidden_state)
+        logits = self.classifier(pooled_output)
 
         loss = None
         if labels is not None:
@@ -1338,9 +1344,8 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
         self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
         self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
         self.classifier = torch.nn.Linear(self.config.hidden_size, self.num_labels)
-        self.
-        self.
-        # self.init_weights()
+        self.softmax = nn.Softmax(dim=-1)
+        self.init_weights()
 
     @add_start_docstrings_to_model_forward(
         OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1366,12 +1371,12 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-
+
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
-
-
+
+        outputs = self.OmniGenome(
             input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -1381,17 +1386,11 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        try:
-            last_hidden_state = mlm_outputs[0]
-            last_hidden_state = self.dense(last_hidden_state)
-        except:
-            last_hidden_state = mlm_outputs.hidden_states[-1]
-            last_hidden_state = self.dense(last_hidden_state)
 
+        last_hidden_state = outputs[0]
+        last_hidden_state = self.dense(last_hidden_state)
         logits = self.classifier(last_hidden_state)
-        logits =
-        logits = self.activation(logits)
-        logits = self.dropout(logits)
+        logits = self.softmax(logits)
 
         loss = None
         if labels is not None:
@@ -1399,14 +1398,14 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         if not return_dict:
-            output = (logits,) +
+            output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
         return TokenClassifierOutput(
             loss=loss,
             logits=logits,
-            hidden_states=
-            attentions=
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
     @staticmethod
@@ -1432,7 +1431,7 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
 
         return structure
 
-    def
+    def predict_rna_structure(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -1457,18 +1456,26 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
 
 @add_start_docstrings(
     """
-
+    This is not a standard Seq2Seq model. Instead, this model is designed for RNA design tasks.
+    This is the OmniGenome Model with a simple genetic algorithm based RNA design head on top.
     """,
     OmniGenome_START_DOCSTRING,
 )
-class
+class OmniGenomeModelForSeq2SeqLM(OmniGenomePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.OmniGenome =
+        self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
+        self.lm_head = OmniGenomeLMHead(config)
         self.num_generation = config.num_generation
         self.num_population = config.num_population
-
+        self.init_weights()
+
+        self.tokenizer = None
+        self.predict_structure = None
+
+        warnings.warn(f"This model {self.__class__.__name__} is not a real Seq2Seq model. "
+                      f"Instead, this model is designed for RNA design tasks")
 
     @add_start_docstrings_to_model_forward(
         OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1494,43 +1501,224 @@ class OmniGenomeMaskedLMForRNADesign(OmniGenomePreTrainedModel):
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-
-
-        )
-
-        outputs = self.OmniGenome(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        sequence_output = outputs[0]
-
-        sequence_output = self.dropout(sequence_output)
-        logits = self.classifier(sequence_output)
+        raise NotImplementedError("This model is not designed for standard Seq2Seq tasks. "
+                                  "Use model.rna_sequence_design() for RNA sequences design instead.")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def rna_sequence_design(
+        self,
+        structure: str,
+        predict_structure_func=None,
+        **kwargs
+    ) -> List[str]:
+        """
+        Assemble the RNA sequence given the reference sequence structure
+        """
+        if self.tokenizer is None:
+            tokenizer = kwargs.get("tokenizer", None)
+            if tokenizer is None:
+                from transformers import AutoTokenizer
+                self.tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
+            else:
+                self.tokenizer = tokenizer
+
+        candidates = self.genetic_algorithm_for_rna_design(structure, predict_structure_func=None, **kwargs)
+
+        return candidates
+
+    def genetic_algorithm_for_rna_design(self, structure, predict_structure_func=None, **kwargs):
+        if predict_structure_func is None:
+            import ViennaRNA
+
+            def predict_structure(sequence):
+                return ViennaRNA.fold(sequence)[0]
+
+            predict_structure_func = predict_structure
+
+        self.predict_structure = predict_structure_func
+        mutation_ratio = kwargs.get("mutation_ratio", 0.2)
+        num_population = kwargs.get("num_population", self.num_population)
+        num_generation = kwargs.get("num_generation", self.num_generation)
+        import tqdm
+        population = self.init_population(structure, num_population)
+        population = self.mlm_mutate(population, structure, mutation_ratio=mutation_ratio)
+        for generation_id in tqdm.tqdm(range(num_generation), desc="Designing RNA Sequence"):
+            population_fitness = self.sequence_fitness(population, structure)[:num_population]
+            population = sorted(zip(population, population_fitness), key=lambda x: x[1])[:num_population]
+            population = [x[0] for x in population]
+            next_generation = population  # Elitism
+            next_generation += self.crossover(population, structure)
+            next_generation += self.mlm_mutate(next_generation, structure, mutation_ratio)
+            fitness_values = self.sequence_fitness(next_generation, structure)
+            next_generation = sorted(zip(next_generation, fitness_values), key=lambda x: x[1])
+
+            candidate_sequences = []
+            for sequence, fitness in next_generation:
+                if fitness == 0:
+                    candidate_sequences.append(sequence)
+                else:
+                    break
+            if candidate_sequences:
+                return candidate_sequences
+            print(f"Generation {generation_id}: {next_generation[0][0]} with fitness {next_generation[0][1]}")
+            population = [x[0] for x in next_generation[:num_population]]
+
+        return []
+
+    def init_population(self, structure, num_population):
+        # Initialize lists to store population data and inputs for masked language model
+        population = []
+        mlm_inputs = []
+        # Iterate over the number of individuals in the population
+        for _ in range(num_population):  # Changed from self.num_population to num_population
+            # Create a sequence by randomly choosing nucleotides or a mask token for each position in the structure
+            masked_sequence = [
+                random.choice(["A", "G", "C", "T", "<mask>"])
+                for _ in range(len(structure))
+            ]
+            masked_sequence_str = "".join(masked_sequence)
+            mlm_inputs.append(f"{masked_sequence_str}<eos>{''.join(structure)}")
+
+        # Call a function to predict outputs using the masked language model
+        outputs = self.mlm_predict(mlm_inputs, structure)
+
+        # Decode the mlm outputs and construct the initial population
+        for i in range(len(outputs)):
+            sequence = self.tokenizer.convert_ids_to_tokens(outputs[i].tolist())
+            fixed_sequence = [
+                x if x in "AGCT" else random.choice(["G", "C"])
+                for x, y in zip(sequence, list(mlm_inputs[i].replace('<mask>', '$')))
+            ]
+            population.append("".join(fixed_sequence))
+
+        return population
+
+    def mlm_mutate(self, population, structure, mutation_ratio=0.2):
+        def mutate(sequence, mutation_rate=0.2):
+            sequence = np.array(list(sequence), dtype=np.str_)
+            probability_matrix = np.full(sequence.shape, mutation_rate)
+            masked_indices = np.random.rand(*sequence.shape) < probability_matrix
+            sequence[masked_indices] = "$"
+            mut_seq = "".join(sequence.tolist()).replace("$", "<mask>")
+            return mut_seq
+
+        def mutate_with_spans_mask(sequence, mutation_rate=0.2):
+            sequence = np.array(list(sequence), dtype=np.str_)
+            length = len(sequence)
+            num_mutations = int(mutation_rate * length)  # Total number of mutations is based on mutation rate
+            # Decide the average span length; we assume mutation spans about 20% of the total mutations length
+            average_span_length = random.randint(1, max(1, int(length * mutation_rate / 10)))
+            # Initialize mutation points
+            mutation_points = np.random.choice(length, num_mutations, replace=False)  # Start points for mutations
+            # Create the mask
+            mask = np.zeros(length, dtype=bool)
+            for start in mutation_points:
+                end = start + average_span_length
+                if end > length:
+                    end = length
+                mask[start:end] = True  # Masking a span from start to end
+            # Apply mask
+            sequence[mask] = "<mask>"
+            # Combine the masked parts with the rest of the sequence
+            mutated_sequence = ''.join(sequence)
+            # Since multiple consecutive '<mask>'s might occur, replace them with a single '<mask>'
+            mutated_sequence = mutated_sequence.replace('<mask>' * average_span_length, '<mask>')
+            return mutated_sequence
+
+        # Initialize lists to store population data and inputs for masked language model
+        mlm_inputs = []
+        masked_sequences = []
+
+        # Iterate over the number of individuals in the population
+        for sequence in population:
+            # Create a sequence by randomly choosing nucleotides or a mask token for each position in the structure
+            if random.random() < 1:
+                masked_sequence = mutate(sequence, mutation_ratio)
+            else:
+                masked_sequence = mutate_with_spans_mask(sequence, mutation_ratio)
+            masked_sequences.append(masked_sequence)
+            mlm_inputs.append(f"{masked_sequence}<eos>{''.join(structure)}")
+
+        # Call a function to predict outputs using the masked language model
+        outputs = self.mlm_predict(mlm_inputs, structure)
+
+        mut_population = []
+
+        # Decode the mlm outputs and construct the initial population
+        for i in range(len(outputs)):
+            sequence = self.tokenizer.convert_ids_to_tokens(outputs[i].tolist())
+            fixed_sequence = [
+                x if x in "AGCT" else random.choice(["G", "C"])
+                for x, y in zip(sequence, list(masked_sequences[i].replace('<mask>', '$')))
+            ]
+            mut_population.append("".join(fixed_sequence))
+
+        return mut_population
+
+    def crossover(self, population, structure):
+        crossover_population = []
+        batch_crossover_inputs = []
+        for i in range(len(population)):
+            parent1, parent2 = random.choices(population, k=2)
+            pos = random.randint(1, len(parent1) - 1)
+            child1 = parent1[:pos] + "<mask>" * len(parent2[pos:])
+            child2 = "<mask>" * len(parent1[:pos]) + parent2[pos:]
+            batch_crossover_inputs.append(f"{child1}<eos>{structure}")
+            batch_crossover_inputs.append(f"{child2}<eos>{structure}")
+
+        outputs = self.mlm_predict(batch_crossover_inputs, structure)
+
+        for i in range(len(outputs)):
+            sequence = self.tokenizer.convert_ids_to_tokens(outputs[i].tolist())
+            fixed_sequence = [
+                x if x in "AGCT" else random.choice(["G", "C"])
+                for x, y in zip(sequence, list(batch_crossover_inputs[i].replace('<mask>', '$')))
+            ]
+            crossover_population.append("".join(fixed_sequence))
+
+        return crossover_population
+
+    def sequence_fitness(self, sequences, structure):
+        fitness_values = []
+        structures = [self.predict_structure(sequence) for sequence in sequences]
+        for predicted_structure in structures:
+            scores = []
+            for i in range(len(predicted_structure)):
+                if predicted_structure[i] == structure[i]:
+                    scores.append(1)
+                elif (
+                    predicted_structure[i] == ")"
+                    and structure[i] == "("
+                    or predicted_structure[i] == "("
+                    and structure[i] == ")"
+                ):
+                    scores.append(-3)
+                else:
+                    scores.append(0)
+            score = 1 - sum(scores) / len(structure)
+            fitness_values.append(score)
+        return fitness_values
+
+    def mlm_predict(self, mlm_inputs, structure):
+        batch_size = 4
+        all_outputs = []
+        from transformers import set_seed
+        set_seed(random.randint(0, 99999999), deterministic=False)
+
+        with torch.no_grad():
+            for i in range(0, len(mlm_inputs), batch_size):
+                batch_mlm_inputs = self.tokenizer(
+                    mlm_inputs[i:i + batch_size],
+                    padding=True,
+                    max_length=len(mlm_inputs[0]) // 2,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                batch_mlm_inputs = batch_mlm_inputs.to(self.device)
+                outputs = self.OmniGenome(**batch_mlm_inputs)[0]
+                outputs = self.lm_head(outputs)
+                outputs = outputs.argmax(dim=-1)
+                all_outputs.append(outputs)
+            outputs = torch.cat(all_outputs, dim=0)
+        return outputs[:, 1:1 + len(structure)]
 
 
 # Copied from transformers.models.esm.modeling_esm.EsmClassificationHead with Esm->OmniGenome
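
The bulk of the Python changes add OmniGenomeModelForSeq2SeqLM, whose forward() deliberately raises NotImplementedError and which instead exposes rna_sequence_design: a masked-LM-driven genetic algorithm (random initial population, MLM-based mutation, crossover, and ViennaRNA-scored fitness). A hedged usage sketch follows; it assumes the hub config resolves an auto class to this model (otherwise import OmniGenomeModelForSeq2SeqLM from the downloaded modeling_omnigenome.py), that the config defines num_generation and num_population, and that ViennaRNA and tqdm are installed for the default structure predictor and progress bar.

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

repo = "yangheng/OmniGenome-52M"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
# Assumption: this auto class resolves to OmniGenomeModelForSeq2SeqLM for this repo.
model = AutoModelForSeq2SeqLM.from_pretrained(repo, trust_remote_code=True)

target_structure = "(((((......)))))"  # dot-bracket target structure
candidates = model.rna_sequence_design(
    target_structure,
    tokenizer=tokenizer,   # skips the AutoTokenizer fallback inside the method
    mutation_ratio=0.2,    # forwarded to the genetic algorithm via **kwargs
    num_population=50,
    num_generation=20,
)
print(candidates)  # sequences whose predicted structure matched the target; may be empty
```

Note that model.generate() is not supported by this class; as the diff shows, its forward() exists only to raise the NotImplementedError pointing callers to rna_sequence_design().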