probablybots committed
Commit bfda7c2 · verified · 1 parent: 762e0a9

Update README.md

Files changed (1): README.md (+8 -8)
README.md CHANGED
@@ -19,8 +19,8 @@ mgen test --model SequenceClassification --model.backbone aido_protein_16b_v1 --
 ```python
 from modelgenerator.tasks import Embed
 model = Embed.from_config({"model.backbone": "aido_protein_16b_v1"}).eval()
-collated_batch = model.collate({"sequences": ["HELLQ", "WRLD"]})
-embedding = model(collated_batch)
+transformed_batch = model.transform({"sequences": ["HELLQ", "WRLD"]})
+embedding = model(transformed_batch)
 print(embedding.shape)
 print(embedding)
 ```
@@ -29,8 +29,8 @@ print(embedding)
 import torch
 from modelgenerator.tasks import SequenceClassification
 model = SequenceClassification.from_config({"model.backbone": "aido_protein_16b_v1", "model.n_classes": 2}).eval()
-collated_batch = model.collate({"sequences": ["HELLQ", "WRLD"]})
-logits = model(collated_batch)
+transformed_batch = model.transform({"sequences": ["HELLQ", "WRLD"]})
+logits = model(transformed_batch)
 print(logits)
 print(torch.argmax(logits, dim=-1))
 ```
@@ -39,8 +39,8 @@ print(torch.argmax(logits, dim=-1))
 import torch
 from modelgenerator.tasks import TokenClassification
 model = TokenClassification.from_config({"model.backbone": "aido_protein_16b_v1", "model.n_classes": 3}).eval()
-collated_batch = model.collate({"sequences": ["HELLQ", "WRLD"]})
-logits = model(collated_batch)
+transformed_batch = model.transform({"sequences": ["HELLQ", "WRLD"]})
+logits = model(transformed_batch)
 print(logits)
 print(torch.argmax(logits, dim=-1))
 ```
@@ -48,8 +48,8 @@ print(torch.argmax(logits, dim=-1))
 ```python
 from modelgenerator.tasks import SequenceRegression
 model = SequenceRegression.from_config({"model.backbone": "aido_protein_16b_v1"}).eval()
-collated_batch = model.collate({"sequences": ["HELLQ", "WRLD"]})
-logits = model(collated_batch)
+transformed_batch = model.transform({"sequences": ["HELLQ", "WRLD"]})
+logits = model(transformed_batch)
 print(logits)
 ```
 
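After this commit, every README snippet prepares its batch with `model.transform` instead of `model.collate`. A minimal sketch of the updated Embed pattern is below; the `torch.no_grad()` guard is an added suggestion and not part of the README diff.

```python
import torch
from modelgenerator.tasks import Embed

# Batch preparation now goes through model.transform (formerly model.collate).
model = Embed.from_config({"model.backbone": "aido_protein_16b_v1"}).eval()

with torch.no_grad():  # inference only; gradient tracking not needed (added suggestion)
    transformed_batch = model.transform({"sequences": ["HELLQ", "WRLD"]})
    embedding = model(transformed_batch)

print(embedding.shape)
print(embedding)
```

The same rename applies unchanged to the SequenceClassification, TokenClassification, and SequenceRegression snippets.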