## Model Overview

AgroNT is a DNA language model trained primarily on edible plant genomes. More specifically, AgroNT uses the transformer architecture with self-attention and a masked language modeling objective to leverage the widely available genotype data from 48 different plant species and learn general representations of nucleotide sequences. AgroNT contains 1 billion parameters and has a context window of 1024 tokens.

AgroNT uses a non-overlapping 6-mer tokenizer to convert genomic nucleotide sequences into tokens. As a result, the 1024 tokens correspond to approximately 6144 base pairs.

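As a quick sanity check on the numbers above (a back-of-the-envelope sketch that assumes every token in the window is a full 6-mer; special and standalone-base tokens make the exact figure slightly smaller):

```python
# Approximate genomic span covered by one AgroNT input window,
# assuming every token is a full 6-mer.
context_tokens = 1024
bases_per_token = 6
print(context_tokens * bases_per_token)  # 6144 base pairs
```
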
## How to use

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch


model_name = 'agro-nt'

# fetch model and tokenizer from InstaDeep's hf repo
agro_nt_model = AutoModelForMaskedLM.from_pretrained(f'InstaDeepAI/{model_name}')
agro_nt_tokenizer = AutoTokenizer.from_pretrained(f'InstaDeepAI/{model_name}')

print(f"Loaded the {model_name} model with {agro_nt_model.num_parameters()} parameters and corresponding tokenizer.")

# example sequence and tokenization
sequences = ['ATATACGGCCGNC']

batch_tokens = agro_nt_tokenizer(sequences, padding='longest')['input_ids']
print(f"Tokenized sequence: {agro_nt_tokenizer.batch_decode(batch_tokens)}")

torch_batch_tokens = torch.tensor(batch_tokens)
attention_mask = torch_batch_tokens != agro_nt_tokenizer.pad_token_id

# inference
outs = agro_nt_model(
    torch_batch_tokens,
    attention_mask=attention_mask,
    encoder_attention_mask=attention_mask,
    output_hidden_states=True
)

# get the final layer embeddings and language model head logits
embeddings = outs['hidden_states'][-1].detach().numpy()
logits = outs['logits'].detach().numpy()
```
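
One common way to get a single fixed-size embedding per sequence is to mean-pool the final hidden states over non-padding positions. The snippet below reuses `outs` and `attention_mask` from the example above; the pooling choice is an illustration, not an official recommendation.

```python
# Mean-pool the final hidden states over non-padding positions to obtain one
# fixed-size embedding per sequence (illustrative pooling choice).
hidden_states = outs['hidden_states'][-1]                    # (batch, seq_len, hidden_dim)
mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
mean_embeddings = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
print(mean_embeddings.shape)                                 # (batch, hidden_dim)
```
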
## Pre-training

#### Data
Our pre-training dataset was built from the reference genomes of (mostly) edible plants contained in the Ensembl Plants database. The dataset consists of approximately 10.5 million genomic sequences across 48 different species.

#### Processing
All reference genomes for each species were assembled into a single FASTA file. In this FASTA file, all nucleotides other than A, T, C, G were replaced by N. We used a tokenizer to convert strings of letters into sequences of tokens. The tokenizer's alphabet consisted of the $4^6 = 4096$ possible 6-mer combinations obtained by combining A, T, C, G, as well as five additional tokens representing standalone A, T, C, G, and N. It also included three special tokens: the padding ([PAD]), masking ([MASK]), and beginning-of-sequence (also called class; [CLS]) tokens. This resulted in a vocabulary of 4104 tokens. To tokenize an input sequence, the tokenizer started with a class token and then converted the sequence from left to right, matching 6-mer tokens when possible, or using the standalone tokens when necessary (for instance, when the letter N was present or when the sequence length was not a multiple of 6).
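
A simplified sketch of that scheme is shown below (for illustration only; the tokenizer shipped with the model may handle edge cases differently):

```python
# Simplified sketch of the tokenization scheme described above: a [CLS] token,
# then left-to-right non-overlapping 6-mers over A/T/C/G, falling back to
# standalone single-base tokens when a full 6-mer cannot be formed.
def toy_tokenize(sequence):
    tokens = ["[CLS]"]
    i = 0
    while i < len(sequence):
        chunk = sequence[i:i + 6]
        if len(chunk) == 6 and set(chunk) <= set("ATCG"):
            tokens.append(chunk)     # one of the 4096 6-mer tokens
            i += 6
        else:
            tokens.append(chunk[0])  # standalone A, T, C, G or N token
            i += 1
    return tokens

print(toy_tokenize("ATATACGGCCGNC"))
# ['[CLS]', 'ATATAC', 'G', 'G', 'C', 'C', 'G', 'N', 'C']
```
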
#### Training
The MLM objective was used to pre-train AgroNT in a self-supervised manner. In a self-supervised learning setting, annotations (supervision) for each sequence are not needed, as we can mask some proportion of the sequence and use the information contained in the unmasked portion to predict the masked locations. This allows us to leverage the vast amount of unlabeled genomic sequencing data available. Specifically, 15% of the tokens in the input sequence are selected for corruption: 80% of these are replaced with a mask token, 10% are randomly replaced by another token from the vocabulary, and the final 10% keep their original token. The tokenized sequence is passed through the model and a cross-entropy loss is computed for the masked tokens. Pre-training was carried out with a sequence length of 1024 tokens and an effective batch size of 1.5M tokens for 315k update steps, resulting in the model training on a total of 472.5B tokens.
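
The corruption step can be sketched as follows (an illustration of the standard BERT-style 80/10/10 scheme with the proportions quoted above, not the actual pre-training code; the token ids in the toy usage are placeholders):

```python
# Sketch of the 15% selection and 80/10/10 corruption described above.
import torch

def mask_tokens(input_ids, mask_token_id, vocab_size, mlm_probability=0.15):
    labels = input_ids.clone()
    # select ~15% of positions; unselected positions are ignored by the loss (-100)
    selected = torch.rand(input_ids.shape) < mlm_probability
    labels[~selected] = -100

    corrupted = input_ids.clone()
    # 80% of the selected positions -> [MASK]
    to_mask = selected & (torch.rand(input_ids.shape) < 0.8)
    corrupted[to_mask] = mask_token_id
    # half of the remaining selected positions (10% of selected) -> random token;
    # the final 10% keep their original token
    to_randomize = selected & ~to_mask & (torch.rand(input_ids.shape) < 0.5)
    corrupted[to_randomize] = torch.randint(vocab_size, input_ids.shape)[to_randomize]
    return corrupted, labels

# toy usage: vocabulary of 4104 tokens, sequence length of 1024 (mask id is a placeholder)
ids = torch.randint(0, 4104, (1, 1024))
corrupted, labels = mask_tokens(ids, mask_token_id=4103, vocab_size=4104)
```
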
#### Hardware
Model pre-training was carried out on Google TPU v4 accelerators, specifically a TPU v4-1024 containing 512 devices. We trained for a total of approximately four days.