yangheng committed
Commit: b0f0a2f
1 Parent(s): b1308ae
Files changed (3):
  1. README.md +3 -3
  2. model.safetensors +1 -1
  3. modeling_omnigenome.py +251 -63
README.md CHANGED
@@ -6,7 +6,7 @@ language:
 
 tags:
 - Genomic-Language-Modeling
-- RNA Genomic Foundation Model
+- OmniGenome Foundation Model
 ---
 
 # Multi-species Foundation Model for Universal RNA and DNA Downstream Tasks
@@ -15,13 +15,13 @@ tags:
 We are keep updating the checkpoints, the current checkpoint is trained for 0.85 epoch.
 
 ## Training Examples
-Refer to GitHub [https://github.com/yangheng95/MP-RNA](https://github.com/yangheng95/MP-RNA)
+Refer to GitHub [https://github.com/yangheng95/OmniGenome](https://github.com/yangheng95/OmniGenome)
 
 ## Usage
 This model is available for replacing genomic foundation models such as CDSBERT, Nucleotide Transformers, DNABERT2, etc.
 ```
 from transformers import AutoModel
-model = AutoModel.from_pretrained("yangheng/MPRNA-52M-v1", trust_remote_code=True)
+model = AutoModel.from_pretrained("yangheng/OmniGenome-52M", trust_remote_code=True)
 ```
 
 ## Subtasks
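The README's Usage snippet only loads the backbone. As a rough end-to-end sketch of the renamed checkpoint (the example RNA fragment is illustrative, and the `last_hidden_state` attribute is assumed from the model's ESM-style outputs, not stated in this commit):

```
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "yangheng/OmniGenome-52M"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)

# Encode an illustrative RNA fragment and extract token-level embeddings.
inputs = tokenizer("GCCCGAAUAGCUCAGCCGGG", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch, sequence_length, hidden_size)
```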
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:26d7498b3a722bfe09ca9b488b38f315f6696d3f7d6722009cbcfbf4e22480b0
+oid sha256:2300e9ae1743dac0e51fe56eec44718d64b408c349e5f0bc98c567d339fe7938
 size 210828112
modeling_omnigenome.py CHANGED
@@ -15,8 +15,11 @@
 """ PyTorch OmniGenome model."""
 
 import math
+import random
+import warnings
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -1117,7 +1120,7 @@ class OmniGenomeForMaskedLM(OmniGenomePreTrainedModel):
 
         self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
         self.lm_head = OmniGenomeLMHead(config)
-        # self.init_weights()
+        self.init_weights()
 
     def get_output_embeddings(self):
         return self.lm_head.decoder
@@ -1236,8 +1239,9 @@ class OmniGenomeForSequenceClassification(OmniGenomePreTrainedModel):
         self.num_labels = config.num_labels
         self.config = config
         self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
+        self.pooler = OmniGenomePooler(config)
         self.classifier = OmniGenomeClassificationHead(config)
-        # self.init_weights()
+        self.init_weights()
 
     @add_start_docstrings_to_model_forward(
         OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1279,8 +1283,10 @@ class OmniGenomeForSequenceClassification(OmniGenomePreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        sequence_output = outputs[0]
-        logits = self.classifier(sequence_output)
+        last_hidden_state = outputs[0]
+        last_hidden_state = self.dense(last_hidden_state)
+        pooled_output = self.pooler(last_hidden_state)
+        logits = self.classifier(pooled_output)
 
         loss = None
         if labels is not None:
@@ -1338,9 +1344,8 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
         self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
         self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
         self.classifier = torch.nn.Linear(self.config.hidden_size, self.num_labels)
-        self.activation = torch.nn.Tanh()
-        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
-        # self.init_weights()
+        self.softmax = nn.Softmax(dim=-1)
+        self.init_weights()
 
     @add_start_docstrings_to_model_forward(
         OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1366,12 +1371,12 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-
+
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
-
-        mlm_outputs = self.OmniGenome(
+
+        outputs = self.OmniGenome(
             input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -1381,17 +1386,11 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        try:
-            last_hidden_state = mlm_outputs[0]
-            last_hidden_state = self.dense(last_hidden_state)
-        except:
-            last_hidden_state = mlm_outputs.hidden_states[-1]
-            last_hidden_state = self.dense(last_hidden_state)
 
+        last_hidden_state = outputs[0]
+        last_hidden_state = self.dense(last_hidden_state)
         logits = self.classifier(last_hidden_state)
-        logits = torch.softmax(logits, dim=-1)
-        logits = self.activation(logits)
-        logits = self.dropout(logits)
+        logits = self.softmax(logits)
 
         loss = None
         if labels is not None:
@@ -1399,14 +1398,14 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         if not return_dict:
-            output = (logits,) + mlm_outputs[2:]
+            output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
         return TokenClassifierOutput(
             loss=loss,
             logits=logits,
-            hidden_states=mlm_outputs.hidden_states,
-            attentions=mlm_outputs.attentions,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
     @staticmethod
@@ -1432,7 +1431,7 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
 
         return structure
 
-    def predict_structure(
+    def predict_rna_structure(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -1457,18 +1456,26 @@
 
 @add_start_docstrings(
     """
-    OmniGenome Model with a simple genetic algorithm based RNA design head on top.
+    This is not a standard Seq2Seq model. Instead, this model is designed for RNA design tasks.
+    This is the OmniGenome Model with a simple genetic algorithm based RNA design head on top.
     """,
     OmniGenome_START_DOCSTRING,
 )
-class OmniGenomeMaskedLMForRNADesign(OmniGenomePreTrainedModel):
+class OmniGenomeModelForSeq2SeqLM(OmniGenomePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.OmniGenome = OmniGenomeForMaskedLM(config)
+        self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
+        self.lm_head = OmniGenomeLMHead(config)
         self.num_generation = config.num_generation
         self.num_population = config.num_population
-        # self.init_weights()
+        self.init_weights()
+
+        self.tokenizer = None
+        self.predict_structure = None
+
+        warnings.warn(f"This model {self.__class__.__name__} is not a real Seq2Seq model. "
+                      f"Instead, this model is designed for RNA design tasks")
 
     @add_start_docstrings_to_model_forward(
         OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1494,43 +1501,224 @@ class OmniGenomeMaskedLMForRNADesign(OmniGenomePreTrainedModel):
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
-        outputs = self.OmniGenome(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        sequence_output = outputs[0]
-
-        sequence_output = self.dropout(sequence_output)
-        logits = self.classifier(sequence_output)
+        raise NotImplementedError("This model is not designed for standard Seq2Seq tasks. "
+                                  "Use model.rna_sequence_design() for RNA sequences design instead.")
 
-        loss = None
-        if labels is not None:
-            loss_fct = CrossEntropyLoss()
-
-            labels = labels.to(logits.device)
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-
-        if not return_dict:
-            output = (logits,) + outputs[2:]
-            return ((loss,) + output) if loss is not None else output
-
-        return TokenClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+    def rna_sequence_design(
+        self,
+        structure: str,
+        predict_structure_func=None,
+        **kwargs
+    ) -> List[str]:
+        """
+        Assemble the RNA sequence given the reference sequence structure
+        """
+        if self.tokenizer is None:
+            tokenizer = kwargs.get("tokenizer", None)
+            if tokenizer is None:
+                from transformers import AutoTokenizer
+                self.tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
+            else:
+                self.tokenizer = tokenizer
+
+        candidates = self.genetic_algorithm_for_rna_design(structure, predict_structure_func=None, **kwargs)
+
+        return candidates
+
+    def genetic_algorithm_for_rna_design(self, structure, predict_structure_func=None, **kwargs):
+        if predict_structure_func is None:
+            import ViennaRNA
+
+            def predict_structure(sequence):
+                return ViennaRNA.fold(sequence)[0]
+
+            predict_structure_func = predict_structure
+
+        self.predict_structure = predict_structure_func
+        mutation_ratio = kwargs.get("mutation_ratio", 0.2)
+        num_population = kwargs.get("num_population", self.num_population)
+        num_generation = kwargs.get("num_generation", self.num_generation)
+        import tqdm
+        population = self.init_population(structure, num_population)
+        population = self.mlm_mutate(population, structure, mutation_ratio=mutation_ratio)
+        for generation_id in tqdm.tqdm(range(num_generation), desc="Designing RNA Sequence"):
+            population_fitness = self.sequence_fitness(population, structure)[:num_population]
+            population = sorted(zip(population, population_fitness), key=lambda x: x[1])[:num_population]
+            population = [x[0] for x in population]
+            next_generation = population  # Elitism
+            next_generation += self.crossover(population, structure)
+            next_generation += self.mlm_mutate(next_generation, structure, mutation_ratio)
+            fitness_values = self.sequence_fitness(next_generation, structure)
+            next_generation = sorted(zip(next_generation, fitness_values), key=lambda x: x[1])
+
+            candidate_sequences = []
+            for sequence, fitness in next_generation:
+                if fitness == 0:
+                    candidate_sequences.append(sequence)
+                else:
+                    break
+            if candidate_sequences:
+                return candidate_sequences
+            print(f"Generation {generation_id}: {next_generation[0][0]} with fitness {next_generation[0][1]}")
+            population = [x[0] for x in next_generation[:num_population]]
+
+        return []
+
+    def init_population(self, structure, num_population):
+        # Initialize lists to store population data and inputs for masked language model
+        population = []
+        mlm_inputs = []
+        # Iterate over the number of individuals in the population
+        for _ in range(num_population):  # Changed from self.num_population to num_population
+            # Create a sequence by randomly choosing nucleotides or a mask token for each position in the structure
+            masked_sequence = [
+                random.choice(["A", "G", "C", "T", "<mask>"])
+                for _ in range(len(structure))
+            ]
+            masked_sequence_str = "".join(masked_sequence)
+            mlm_inputs.append(f"{masked_sequence_str}<eos>{''.join(structure)}")
+
+        # Call a function to predict outputs using the masked language model
+        outputs = self.mlm_predict(mlm_inputs, structure)
+
+        # Decode the mlm outputs and construct the initial population
+        for i in range(len(outputs)):
+            sequence = self.tokenizer.convert_ids_to_tokens(outputs[i].tolist())
+            fixed_sequence = [
+                x if x in "AGCT" else random.choice(["G", "C"])
+                for x, y in zip(sequence, list(mlm_inputs[i].replace('<mask>', '$')))
+            ]
+            population.append("".join(fixed_sequence))
+
+        return population
+
+    def mlm_mutate(self, population, structure, mutation_ratio=0.2):
+        def mutate(sequence, mutation_rate=0.2):
+            sequence = np.array(list(sequence), dtype=np.str_)
+            probability_matrix = np.full(sequence.shape, mutation_rate)
+            masked_indices = np.random.rand(*sequence.shape) < probability_matrix
+            sequence[masked_indices] = "$"
+            mut_seq = "".join(sequence.tolist()).replace("$", "<mask>")
+            return mut_seq
+        def mutate_with_spans_mask(sequence, mutation_rate=0.2):
+            sequence = np.array(list(sequence), dtype=np.str_)
+            length = len(sequence)
+            num_mutations = int(mutation_rate * length)  # Total number of mutations is based on mutation rate
+            # Decide the average span length; we assume mutation spans about 20% of the total mutations length
+            average_span_length = random.randint(1, max(1, int(length * mutation_rate / 10)))
+            # Initialize mutation points
+            mutation_points = np.random.choice(length, num_mutations, replace=False)  # Start points for mutations
+            # Create the mask
+            mask = np.zeros(length, dtype=bool)
+            for start in mutation_points:
+                end = start + average_span_length
+                if end > length:
+                    end = length
+                mask[start:end] = True  # Masking a span from start to end
+            # Apply mask
+            sequence[mask] = "<mask>"
+            # Combine the masked parts with the rest of the sequence
+            mutated_sequence = ''.join(sequence)
+            # Since multiple consecutive '<mask>'s might occur, replace them with a single '<mask>'
+            mutated_sequence = mutated_sequence.replace('<mask>' * average_span_length, '<mask>')
+            return mutated_sequence
+
+        # Initialize lists to store population data and inputs for masked language model
+        mlm_inputs = []
+        masked_sequences = []
+
+        # Iterate over the number of individuals in the population
+        for sequence in population:
+            # Create a sequence by randomly choosing nucleotides or a mask token for each position in the structure
+            if random.random() < 1:
+                masked_sequence = mutate(sequence, mutation_ratio)
+            else:
+                masked_sequence = mutate_with_spans_mask(sequence, mutation_ratio)
+            masked_sequences.append(masked_sequence)
+            mlm_inputs.append(f"{masked_sequence}<eos>{''.join(structure)}")
+
+        # Call a function to predict outputs using the masked language model
+        outputs = self.mlm_predict(mlm_inputs, structure)
+
+        mut_population = []
+
+        # Decode the mlm outputs and construct the initial population
+        for i in range(len(outputs)):
+            sequence = self.tokenizer.convert_ids_to_tokens(outputs[i].tolist())
+            fixed_sequence = [
+                x if x in "AGCT" else random.choice(["G", "C"])
+                for x, y in zip(sequence, list(masked_sequences[i].replace('<mask>', '$')))
+            ]
+            mut_population.append("".join(fixed_sequence))
+
+        return mut_population
+
+    def crossover(self, population, structure):
+        crossover_population = []
+        batch_crossover_inputs = []
+        for i in range(len(population)):
+            parent1, parent2 = random.choices(population, k=2)
+            pos = random.randint(1, len(parent1) - 1)
+            child1 = parent1[:pos] + "<mask>" * len(parent2[pos:])
+            child2 = "<mask>" * len(parent1[:pos]) + parent2[pos:]
+            batch_crossover_inputs.append(f"{child1}<eos>{structure}")
+            batch_crossover_inputs.append(f"{child2}<eos>{structure}")
+
+        outputs = self.mlm_predict(batch_crossover_inputs, structure)
+
+        for i in range(len(outputs)):
+            sequence = self.tokenizer.convert_ids_to_tokens(outputs[i].tolist())
+            fixed_sequence = [
+                x if x in "AGCT" else random.choice(["G", "C"])
+                for x, y in zip(sequence, list(batch_crossover_inputs[i].replace('<mask>', '$')))
+            ]
+            crossover_population.append("".join(fixed_sequence))
+
+        return crossover_population
+
+    def sequence_fitness(self, sequences, structure):
+        fitness_values = []
+        structures = [self.predict_structure(sequence) for sequence in sequences]
+        for predicted_structure in structures:
+            scores = []
+            for i in range(len(predicted_structure)):
+                if predicted_structure[i] == structure[i]:
+                    scores.append(1)
+                elif (
+                    predicted_structure[i] == ")"
+                    and structure[i] == "("
+                    or predicted_structure[i] == "("
+                    and structure[i] == ")"
+                ):
+                    scores.append(-3)
+                else:
+                    scores.append(0)
+            score = 1 - sum(scores) / len(structure)
+            fitness_values.append(score)
+        return fitness_values
+
+    def mlm_predict(self, mlm_inputs, structure):
+        batch_size = 4
+        all_outputs = []
+        from transformers import set_seed
+        set_seed(random.randint(0, 99999999), deterministic=False)
+
+        with torch.no_grad():
+            for i in range(0, len(mlm_inputs), batch_size):
+                batch_mlm_inputs = self.tokenizer(
+                    mlm_inputs[i:i + batch_size],
+                    padding=True,
+                    max_length=len(mlm_inputs[0]) // 2,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                batch_mlm_inputs = batch_mlm_inputs.to(self.device)
+                outputs = self.OmniGenome(**batch_mlm_inputs)[0]
+                outputs = self.lm_head(outputs)
+                outputs = outputs.argmax(dim=-1)
+                all_outputs.append(outputs)
+            outputs = torch.cat(all_outputs, dim=0)
+        return outputs[:, 1:1 + len(structure)]
 
 
 # Copied from transformers.models.esm.modeling_esm.EsmClassificationHead with Esm->OmniGenome
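The forward() of the new OmniGenomeModelForSeq2SeqLM intentionally raises, so the genetic-algorithm path added in this commit is only reachable through rna_sequence_design(). Below is a minimal, hedged sketch of a call; the AutoModelForSeq2SeqLM routing via trust_remote_code and the example structure string are assumptions for illustration, while the method name, its parameters, and the ViennaRNA/tqdm dependencies come from the code above.

```
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "yangheng/OmniGenome-52M"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Assumption: the repo's auto_map routes AutoModelForSeq2SeqLM to the
# OmniGenomeModelForSeq2SeqLM class defined in this file.
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True)
model.eval()

# Target secondary structure in dot-bracket notation (illustrative).
structure = "((((....))))"

# Runs the MLM-seeded genetic algorithm above: init_population, mlm_mutate,
# crossover, and a ViennaRNA.fold() based fitness check each generation.
# Requires the ViennaRNA and tqdm packages to be installed.
candidates = model.rna_sequence_design(structure, tokenizer=tokenizer)
print(candidates)  # sequences predicted to fold into `structure`, or [] if none found
```

Per the loop in genetic_algorithm_for_rna_design, the search returns as soon as any candidate reaches fitness 0 (its predicted structure exactly matches the target) and otherwise returns an empty list after num_generation generations.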