etrop committed
Commit 0db4b4b
1 Parent(s): 91c65e2

Update README.md

Files changed (1)
  1. README.md +63 -6
README.md CHANGED
@@ -1,10 +1,67 @@
  ## Model Overview
  AgroNt is a DNA language model trained on primarily edible plant genomes. More specifically, AgroNT uses the transformer architecture with self-attention and a masked language modeling
- objective to leverage highly available genotype data from 48 different plant speices. AgroNt contains 1 billion parameters and has a context window of 1000 tokens. AgroNt uses a non-overlapping
- 6-mer tokenizer to convert genomic nucletoide sequences to tokens. As a result the 1000 tokens correspond to approximately 6000 base pairs.


- ## Using the Model from HF
- '''python
- Will update once it it public
- '''
  ## Model Overview
  AgroNT is a DNA language model trained primarily on edible plant genomes. More specifically, AgroNT uses the transformer architecture with self-attention and a masked language modeling
+ objective to leverage highly available genotype data from 48 different plant species to learn general representations of nucleotide sequences. AgroNT contains 1 billion parameters and has a context window of 1024 tokens.
+ AgroNT uses a non-overlapping 6-mer tokenizer to convert genomic nucleotide sequences into tokens. As a result, the 1024-token context window corresponds to approximately 6144 base pairs.

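For intuition, here is a small illustrative sketch (not part of the committed README) of how a non-overlapping 6-mer scheme maps nucleotides to tokens, and why a 1024-token window covers roughly 6144 base pairs:

```python
# illustrative only: split a DNA string into non-overlapping 6-mers
sequence = "ATGCGTACGTTAGCCGTA"  # 18 bp
kmers = [sequence[i:i + 6] for i in range(0, len(sequence), 6)]
print(kmers)        # ['ATGCGT', 'ACGTTA', 'GCCGTA'] -> 3 tokens
print(1024 * 6)     # 6144: approximate base-pair coverage of a 1024-token window
```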
+ ## How to use
+ ```python
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+ import torch
+
+ model_name = 'agro-nt'
+
+ # fetch the model and tokenizer from InstaDeep's Hugging Face repo
+ agro_nt_model = AutoModelForMaskedLM.from_pretrained(f'InstaDeepAI/{model_name}')
+ agro_nt_tokenizer = AutoTokenizer.from_pretrained(f'InstaDeepAI/{model_name}')
+
+ print(f"Loaded the {model_name} model with {agro_nt_model.num_parameters()} parameters and corresponding tokenizer.")
+
+ # example sequence and tokenization
+ sequences = ['ATATACGGCCGNC']
+
+ batch_tokens = agro_nt_tokenizer(sequences)['input_ids']
+ print(f"Tokenized sequence: {agro_nt_tokenizer.batch_decode(batch_tokens)}")
+
+ torch_batch_tokens = torch.tensor(batch_tokens)
+ attention_mask = torch_batch_tokens != agro_nt_tokenizer.pad_token_id
+
+ # inference
+ outs = agro_nt_model(
+     torch_batch_tokens,
+     attention_mask=attention_mask,
+     encoder_attention_mask=attention_mask,
+     output_hidden_states=True
+ )
+
+ # get the final-layer embeddings and language model head logits
+ embeddings = outs['hidden_states'][-1].detach().numpy()
+ logits = outs['logits'].detach().numpy()
+ ```
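As a further usage sketch (not part of the committed README), the LM-head logits can also be used to fill in a masked position. This continues from the variables defined in the snippet above, assumes the standard Hugging Face masked-LM interface (`mask_token_id`, `decode`), and masks position 2 purely for illustration:

```python
# continues from the snippet above; mask one position and predict it
masked_tokens = torch_batch_tokens.clone()
masked_tokens[0, 2] = agro_nt_tokenizer.mask_token_id  # arbitrary position for illustration

with torch.no_grad():
    masked_outs = agro_nt_model(masked_tokens, attention_mask=attention_mask)

# most likely token at the masked position according to the LM head
predicted_id = masked_outs['logits'][0, 2].argmax(-1).item()
print(f"Predicted token: {agro_nt_tokenizer.decode([predicted_id])}")
```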
+
+
+ ## Pre-training
+
+ #### Data
+ Our pre-training dataset was built from the reference genomes of (mostly) edible plants contained in the Ensembl Plants database.
+ The dataset consists of approximately 10.5 million genomic sequences across 48 different species.
+
+ #### Processing
+ The reference genomes for each species were assembled into a single FASTA file, in which all nucleotides other than A, T, C, G were replaced by N. We used a tokenizer to convert strings of letters into sequences of tokens.
+ The tokenizer's alphabet consisted of the $4^6 = 4096$ possible 6-mer combinations obtained by combining A, T, C, G, as well as five additional tokens
+ representing standalone A, T, C, G, and N. It also included three special tokens: the padding [PAD], masking [MASK], and the beginning-of-sequence
+ (also called class; [CLS]) token. This resulted in a vocabulary of 4104 tokens. To tokenize an input sequence, the tokenizer started with a class token and
+ then converted the sequence from left to right, matching 6-mer tokens when possible, or using the standalone tokens when necessary (for instance, when the letter
+ N was present or if the sequence length was not a multiple of 6).
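A simplified sketch of the tokenization scheme described above (one plausible reading of the description, not the actual AgroNT tokenizer implementation):

```python
from itertools import product

BASES = "ATCG"
SPECIAL = ["[PAD]", "[MASK]", "[CLS]"]
STANDALONE = list("ATCGN")
KMERS = ["".join(p) for p in product(BASES, repeat=6)]  # 4^6 = 4096 6-mers
VOCAB = SPECIAL + STANDALONE + KMERS                     # 3 + 5 + 4096 = 4104 tokens

def tokenize(seq: str) -> list[str]:
    """Left to right: prepend [CLS], emit a 6-mer when the next six characters
    are all A/T/C/G, otherwise fall back to a standalone single-letter token."""
    tokens, i = ["[CLS]"], 0
    while i < len(seq):
        chunk = seq[i:i + 6]
        if len(chunk) == 6 and all(c in BASES for c in chunk):
            tokens.append(chunk)
            i += 6
        else:
            tokens.append(seq[i])
            i += 1
    return tokens

print(len(VOCAB))                 # 4104
print(tokenize("ATATACGGCCGNC"))  # ['[CLS]', 'ATATAC', 'G', 'G', 'C', 'C', 'G', 'N', 'C']
```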
+
+ #### Training
+ The MLM objective was used to pre-train AgroNT in a self-supervised manner. In a self-supervised learning setting, annotations (supervision) for each sequence
+ are not needed, as we can mask some proportion of the sequence and use the information contained in the unmasked portion of the sequence to predict the masked locations.
+ This allows us to leverage the vast amount of unlabeled genomic sequencing data available. Specifically, 15% of the tokens in the input sequence are selected for corruption:
+ 80% of these are replaced with a mask token, 10% are randomly replaced by another token from the vocabulary, and the final 10% keep their original token.
+ The tokenized sequence is passed through the model and a cross-entropy loss is computed for the masked tokens. Pre-training was carried out with a sequence length of 1024 tokens
+ and an effective batch size of 1.5M tokens for 315k update steps, resulting in the model training on a total of 472.5B tokens.
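A schematic sketch of the 15% / 80-10-10 corruption rule described above (standard BERT-style masking shown for illustration, not the actual AgroNT training code; the `mlm_mask` helper name and the `-100` ignore-index convention are assumptions):

```python
import torch

def mlm_mask(tokens: torch.Tensor, mask_token_id: int, vocab_size: int,
             select_prob: float = 0.15) -> tuple[torch.Tensor, torch.Tensor]:
    """Select ~15% of positions as prediction targets; of those, replace 80%
    with the mask token, 10% with a random token, and leave 10% unchanged."""
    labels = tokens.clone()
    selected = torch.rand(tokens.shape) < select_prob
    labels[~selected] = -100  # positions that do not contribute to the loss

    corrupted = tokens.clone()
    roll = torch.rand(tokens.shape)
    corrupted[selected & (roll < 0.8)] = mask_token_id
    random_ids = torch.randint(vocab_size, tokens.shape)
    swap = selected & (roll >= 0.8) & (roll < 0.9)
    corrupted[swap] = random_ids[swap]
    # remaining selected positions (roll >= 0.9) keep their original token
    return corrupted, labels

# cross entropy is then computed only on the selected positions (labels != -100)
```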
+
+ #### Hardware
+ Model pre-training was carried out using Google TPU v4 accelerators, specifically a TPU v4-1024 containing 512 devices. We trained for a total of approximately four days.