probablybots commited on
Commit
119a760
·
verified ·
1 Parent(s): c3b4f26

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -4
README.md CHANGED
@@ -41,7 +41,7 @@ mgen test --model SequenceClassification --model.backbone aido_rna_1b600m --data
41
  #### Embedding
42
  ```python
43
  from modelgenerator.tasks import Embed
44
- model = Embed.from_config({"model.backbone": "aido_dna_7b"}).eval()
45
  collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
46
  embedding = model(collated_batch)
47
  print(embedding.shape)
@@ -51,7 +51,7 @@ print(embedding)
51
  ```python
52
  import torch
53
  from modelgenerator.tasks import SequenceClassification
54
- model = SequenceClassification.from_config({"model.backbone": "aido_dna_7b", "model.n_classes": 2}).eval()
55
  collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
56
  logits = model(collated_batch)
57
  print(logits)
@@ -61,7 +61,7 @@ print(torch.argmax(logits, dim=-1))
61
  ```python
62
  import torch
63
  from modelgenerator.tasks import TokenClassification
64
- model = TokenClassification.from_config({"model.backbone": "aido_dna_7b", "model.n_classes": 3}).eval()
65
  collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
66
  logits = model(collated_batch)
67
  print(logits)
@@ -70,10 +70,11 @@ print(torch.argmax(logits, dim=-1))
70
  #### Sequence-level Regression
71
  ```python
72
  from modelgenerator.tasks import SequenceRegression
73
- model = SequenceRegression.from_config({"model.backbone": "aido_dna_7b"}).eval()
74
  collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
75
  logits = model(collated_batch)
76
  print(logits)
 
77
 
78
  ### Get RNA sequence embedding
79
  ```python
 
41
  #### Embedding
42
  ```python
43
  from modelgenerator.tasks import Embed
44
+ model = Embed.from_config({"model.backbone": "aido_rna_1b600m"}).eval()
45
  collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
46
  embedding = model(collated_batch)
47
  print(embedding.shape)
 
51
  ```python
52
  import torch
53
  from modelgenerator.tasks import SequenceClassification
54
+ model = SequenceClassification.from_config({"model.backbone": "aido_rna_1b600m", "model.n_classes": 2}).eval()
55
  collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
56
  logits = model(collated_batch)
57
  print(logits)
 
61
  ```python
62
  import torch
63
  from modelgenerator.tasks import TokenClassification
64
+ model = TokenClassification.from_config({"model.backbone": "aido_rna_1b600m", "model.n_classes": 3}).eval()
65
  collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
66
  logits = model(collated_batch)
67
  print(logits)
 
70
  #### Sequence-level Regression
71
  ```python
72
  from modelgenerator.tasks import SequenceRegression
73
+ model = SequenceRegression.from_config({"model.backbone": "aido_rna_1b600m"}).eval()
74
  collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
75
  logits = model(collated_batch)
76
  print(logits)
77
+ ```
78
 
79
  ### Get RNA sequence embedding
80
  ```python