qilowoq committed
Commit d3b48ad
1 Parent(s): eeed8d4

Update README.md

Files changed (1)
  1. README.md +53 -6
README.md CHANGED
@@ -12,20 +12,19 @@ tags:
  - OAS
  ---

- # AbLang model for light chains
+ ### AbLang model for light chains

  This is a Hugging Face version of AbLang: a language model for antibodies. It was introduced in
  [this paper](https://doi.org/10.1101/2022.01.20.477061) and first released in
  [this repository](https://github.com/oxpig/AbLang). The model is trained on uppercase amino acids: it only works with capital-letter amino acid codes.

-
- # Intended uses & limitations
+ ### Intended uses & limitations

  The model can be used for protein feature extraction or fine-tuned on downstream tasks (TBA).

  ### How to use

- Since this is a custom model, you need to install additional dependencies:
+ Here is how to use this model to get the features of a given antibody sequence in PyTorch:

  ```python
  pip install ablang
@@ -47,7 +46,55 @@ model_output = model(**encoded_input)
  Sequence embeddings can be produced as follows:

  ```python
- seq_embs = model_output.last_hidden_state[:, 0, :]
+ def get_sequence_embeddings(encoded_input, model_output):
+     mask = encoded_input['attention_mask'].float()
+     d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()}  # dict of sep token positions (last non-pad index per row)
+     # make the sep token invisible
+     for i in d:
+         mask[i, d[i]] = 0
+     mask[:, 0] = 0.0  # make the cls token invisible
+     mask = mask.unsqueeze(-1).expand(model_output.last_hidden_state.size())
+     sum_embeddings = torch.sum(model_output.last_hidden_state * mask, 1)
+     sum_mask = torch.clamp(mask.sum(1), min=1e-9)
+     return sum_embeddings / sum_mask
+
+ seq_embeds = get_sequence_embeddings(encoded_input, model_output)
+ ```
+
+ ### Fine-tune
+
+ To save memory we recommend using [LoRA](https://doi.org/10.48550/arXiv.2106.09685):
+
+ ```python
+ pip install git+https://github.com/huggingface/peft.git
+ pip install loralib
+ ```
+
+ LoRA greatly reduces the number of trainable parameters and performs on par with or better than fine-tuning the full model.
+
+ ```python
+ from peft import LoraConfig, get_peft_model
+
+ def apply_lora_bert(model):
+     config = LoraConfig(
+         r=8, lora_alpha=32,
+         lora_dropout=0.3,
+         target_modules=['query', 'value']
+     )
+     for param in model.parameters():
+         param.requires_grad = False  # freeze the model - train adapters later
+         if param.ndim == 1:
+             # cast the small parameters (e.g. layernorm) to fp32 for stability
+             param.data = param.data.to(torch.float32)
+     model.gradient_checkpointing_enable()  # reduce the number of stored activations
+     model.enable_input_require_grads()
+     model = get_peft_model(model, config)
+     return model
+
+ model = apply_lora_bert(model)
+
+ model.print_trainable_parameters()
+ # trainable params: 294912 || all params: 85493760 || trainable%: 0.3449514911965505
  ```

  ### Citation
@@ -59,4 +106,4 @@ seq_embs = model_output.last_hidden_state[:, 0, :]
  doi={https://doi.org/10.1101/2022.01.20.477061},
  year={2022}
  }
- ```
+ ```
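
The code that builds `encoded_input` and `model_output` lives in the unchanged part of README.md, so it does not appear in the hunks above; only the context line `model_output = model(**encoded_input)` points at it, and it presumably also provides the `torch` import used by `get_sequence_embeddings`. Below is a minimal sketch of that loading step, assuming the standard `transformers` auto classes; the checkpoint id `qilowoq/AbLang_light`, the `trust_remote_code=True` flag, and the example fragment are illustrative assumptions, not part of this commit.

```python
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "qilowoq/AbLang_light"  # assumed repo id; check the model card for the exact name
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)  # custom model code

# The README stresses uppercase input: keep the sequence in capital-letter amino acid codes.
sequence = "GSELTQDPAVSVALGQTVRITCQGDSLRSYYASWYQQKPGQAPVLVIY"  # illustrative light-chain fragment
encoded_input = tokenizer(sequence, return_tensors="pt")  # see the unchanged README for the exact tokenizer call

with torch.no_grad():
    model_output = model(**encoded_input)

print(model_output.last_hidden_state.shape)  # (1, number of tokens, hidden size)
```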
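
The `get_sequence_embeddings` helper added in the second hunk replaces the old cls-token embedding (`last_hidden_state[:, 0, :]`) with a mean over residue positions, masking out the cls and sep tokens. As a hypothetical illustration of using those vectors as features (the sequences and padded-batch layout below are assumptions, reusing the tokenizer, model, and helper from above), cosine similarity gives a simple sequence-level comparison:

```python
import torch
import torch.nn.functional as F

# Two illustrative uppercase sequences, tokenized as a right-padded batch so that the
# last non-pad position of each row is its sep token, as the helper expects.
sequences = [
    "GSELTQDPAVSVALGQTVRITCQGDSLRSYYASWYQQKPGQAPVLVIY",
    "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIY",
]
encoded_input = tokenizer(sequences, return_tensors="pt", padding=True)
with torch.no_grad():
    model_output = model(**encoded_input)

seq_embeds = get_sequence_embeddings(encoded_input, model_output)  # shape (2, hidden size)
similarity = F.cosine_similarity(seq_embeds[0:1], seq_embeds[1:2]).item()
print(f"cosine similarity: {similarity:.3f}")
```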
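
The `print_trainable_parameters()` output committed above (294,912 trainable out of 85,493,760) is consistent with rank-8 LoRA adapters on the query and value projections of a 12-layer encoder with hidden size 768; those dimensions are inferred from the totals, not stated in the diff. A back-of-the-envelope check:

```python
# Assumed architecture: 12 layers, hidden size 768 (inferred, not stated in the commit).
num_layers, hidden_size, rank = 12, 768, 8
adapted_modules = 2  # target_modules=['query', 'value']
lora_factors = 2     # each adapted projection gets an A (rank x hidden) and a B (hidden x rank) matrix

trainable = num_layers * adapted_modules * lora_factors * hidden_size * rank
print(trainable)                   # 294912, matching the committed printout
print(100 * trainable / 85493760)  # ~0.345%, matching the reported trainable%
```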
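
After `apply_lora_bert`, only the LoRA adapters receive gradients, so a downstream task needs its own head on top of the pooled embeddings. The commit leaves the fine-tuning recipe as TBA; the single training step below is a hypothetical sketch (the regression head, learning rate, and dummy targets are all assumptions), reusing `encoded_input` and the helper from the examples above.

```python
import torch

head = torch.nn.Linear(768, 1)  # hidden size 768 is an assumption, as above
trainable_params = [p for p in model.parameters() if p.requires_grad] + list(head.parameters())
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)

model.train()
model_output = model(**encoded_input)  # batch from the example above
embeddings = get_sequence_embeddings(encoded_input, model_output)
targets = torch.zeros(embeddings.size(0))  # dummy regression targets, illustration only
loss = torch.nn.functional.mse_loss(head(embeddings).squeeze(-1), targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```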