probablybots
commited on
Update README.md
Browse files
README.md
CHANGED
@@ -41,7 +41,7 @@ mgen test --model SequenceClassification --model.backbone aido_rna_1b600m --data
|
|
41 |
#### Embedding
|
42 |
```python
|
43 |
from modelgenerator.tasks import Embed
|
44 |
-
model = Embed.from_config({"model.backbone": "
|
45 |
collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
|
46 |
embedding = model(collated_batch)
|
47 |
print(embedding.shape)
|
@@ -51,7 +51,7 @@ print(embedding)
|
|
51 |
```python
|
52 |
import torch
|
53 |
from modelgenerator.tasks import SequenceClassification
|
54 |
-
model = SequenceClassification.from_config({"model.backbone": "
|
55 |
collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
|
56 |
logits = model(collated_batch)
|
57 |
print(logits)
|
@@ -61,7 +61,7 @@ print(torch.argmax(logits, dim=-1))
|
|
61 |
```python
|
62 |
import torch
|
63 |
from modelgenerator.tasks import TokenClassification
|
64 |
-
model = TokenClassification.from_config({"model.backbone": "
|
65 |
collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
|
66 |
logits = model(collated_batch)
|
67 |
print(logits)
|
@@ -70,10 +70,11 @@ print(torch.argmax(logits, dim=-1))
|
|
70 |
#### Sequence-level Regression
|
71 |
```python
|
72 |
from modelgenerator.tasks import SequenceRegression
|
73 |
-
model = SequenceRegression.from_config({"model.backbone": "
|
74 |
collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
|
75 |
logits = model(collated_batch)
|
76 |
print(logits)
|
|
|
77 |
|
78 |
### Get RNA sequence embedding
|
79 |
```python
|
|
|
41 |
#### Embedding
|
42 |
```python
|
43 |
from modelgenerator.tasks import Embed
|
44 |
+
model = Embed.from_config({"model.backbone": "aido_rna_1b600m"}).eval()
|
45 |
collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
|
46 |
embedding = model(collated_batch)
|
47 |
print(embedding.shape)
|
|
|
51 |
```python
|
52 |
import torch
|
53 |
from modelgenerator.tasks import SequenceClassification
|
54 |
+
model = SequenceClassification.from_config({"model.backbone": "aido_rna_1b600m", "model.n_classes": 2}).eval()
|
55 |
collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
|
56 |
logits = model(collated_batch)
|
57 |
print(logits)
|
|
|
61 |
```python
|
62 |
import torch
|
63 |
from modelgenerator.tasks import TokenClassification
|
64 |
+
model = TokenClassification.from_config({"model.backbone": "aido_rna_1b600m", "model.n_classes": 3}).eval()
|
65 |
collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
|
66 |
logits = model(collated_batch)
|
67 |
print(logits)
|
|
|
70 |
#### Sequence-level Regression
|
71 |
```python
|
72 |
from modelgenerator.tasks import SequenceRegression
|
73 |
+
model = SequenceRegression.from_config({"model.backbone": "aido_rna_1b600m"}).eval()
|
74 |
collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
|
75 |
logits = model(collated_batch)
|
76 |
print(logits)
|
77 |
+
```
|
78 |
|
79 |
### Get RNA sequence embedding
|
80 |
```python
|