File size: 10,079 Bytes
acb7e86
 
 
818ca8a
acb7e86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
818ca8a
acb7e86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
818ca8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acb7e86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from transformers import GPT2Config, AutoTokenizer, GPT2Config
from transformers import PretrainedConfig, PreTrainedModel
import transformers
from typing import Optional, Tuple, Callable, List
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
from .utils import CABlock, _GPT2LMHeadModel
from .configuration_prot2text import Prot2TextConfig
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList


class Prot2TextModel(PreTrainedModel):
    """Protein-to-text model: an optional ESM protein encoder feeding a GPT-2 decoder.

    When ``config.esm`` is set, amino-acid token ids are encoded by an ESM model,
    projected to the GPT-2 embedding width by ``to_embedding`` and handed to the
    decoder as ``encoder_hidden_states`` (cross-attention input).

    NOTE(review): ``config.rgcn`` / ``config.cross_esm_graph`` are read, but no
    graph encoder is built here, and the ``h`` / ``ln_f`` modules created for the
    ``cross_esm_graph and rgcn`` case are never used by ``forward``.
    """

    config_class = Prot2TextConfig
    # State-dict keys matching this pattern may be missing on load without a warning.
    _keys_to_ignore_on_load_missing = [r"transformer"]
    base_model_prefix = "decoder"

    def __init__(self, config):
        super().__init__(config)

        self.gpt_config = GPT2Config.from_dict(config.gpt_config)

        # GPT-2 language-model head used as the text decoder.
        self.decoder = _GPT2LMHeadModel(self.gpt_config)

        # If ESM encodes the protein sequence, define the ESM layer, the projection
        # layer (ESM hidden size -> GPT-2 n_embd) and the optional fusion layers.
        if config.esm:
            self.esm_config = PretrainedConfig.from_dict(config.esm_config)
            self.esm = transformers.EsmModel(self.esm_config)
            self.to_embedding = nn.Linear(self.esm_config.hidden_size, self.gpt_config.n_embd)
            if config.cross_esm_graph and config.rgcn:
                self.h = nn.ModuleList([CABlock(self.gpt_config, layer_idx=i) for i in range(4)])
                self.ln_f = nn.LayerNorm(self.gpt_config.n_embd, eps=self.gpt_config.layer_norm_epsilon)

        self.config = config

    def get_encoder(self):
        """Return the protein encoder (the ESM model), or ``None`` if ``config.esm`` is off."""
        # The model defines no ``encoder`` attribute; the ESM model plays that role.
        return getattr(self, "esm", None)

    def get_decoder(self):
        """Return the GPT-2 decoder."""
        return self.decoder

    def get_input_embeddings(self):
        """Return the token-embedding layer (``wte``) of the GPT-2 transformer."""
        if hasattr(self, "transformer"):
            return self.transformer.wte
        return self.decoder.transformer.wte

    def warm_up(self, gpt_model=None, esm_model=None):
        """Replace the encoder and/or decoder with pretrained weights.

        Args:
            gpt_model: name or path passed to ``_GPT2LMHeadModel.from_pretrained``;
                the loaded decoder gets cross-attention enabled, its token
                embeddings resized to ``gpt_config.vocab_size`` and its config
                replaced by ``self.gpt_config``.
            esm_model: name or path passed to ``transformers.EsmModel.from_pretrained``.
        """
        if esm_model is not None:
            self.esm = transformers.EsmModel.from_pretrained(esm_model)
        if gpt_model is not None:
            self.decoder = _GPT2LMHeadModel.from_pretrained(gpt_model, add_cross_attention=True, use_cache=False)
            self.decoder.resize_token_embeddings(self.gpt_config.vocab_size)
            self.decoder.config = self.gpt_config

    def forward(self,
                encoder_input_ids: Optional[torch.LongTensor] = None,
                edge_index: Optional[torch.LongTensor] = None,
                batch: Optional[torch.LongTensor] = None,
                x: Optional[torch.FloatTensor] = None,
                edge_type: Optional[torch.LongTensor] = None,
                decoder_input_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                past_key_values_graph_esm: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                decoder_attention_mask: Optional[torch.FloatTensor] = None,
                attention_mask: Optional[torch.FloatTensor] = None,
                token_type_ids: Optional[torch.LongTensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                head_mask: Optional[torch.FloatTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                encoder_hidden_states: Optional[torch.Tensor] = None,
                encoder_attention_mask: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                get_graph_emb: Optional[bool] = False,
                **delete_args,
            ):
        """Encode the protein (if ESM is enabled) and run the GPT-2 decoder on it.

        Args:
            encoder_input_ids: ESM token ids of the amino-acid sequence; required
                when ``config.esm`` is set. For ``prot2text_version == '1.0'``
                their second dimension must be exactly 1021.
            attention_mask: mask for ``encoder_input_ids``; forwarded to ESM and
                used as the decoder's cross-attention mask (forced to ``None``
                for version '1.0').
            decoder_input_ids: decoder token ids; a 3-D tensor has its first
                dimension squeezed away.
            encoder_hidden_states / encoder_attention_mask: only used when
                ``config.esm`` is off, as caller-supplied cross-attention input.
            get_graph_emb: if True, return the (projected) encoder embeddings
                instead of running the decoder.
            edge_index, batch, x, edge_type, past_key_values_graph_esm,
            **delete_args: accepted for interface compatibility; unused here.

        Returns:
            The encoder embeddings when ``get_graph_emb`` is True, otherwise the
            output of ``self.decoder(...)``.

        Raises:
            ValueError: ESM is enabled but ``encoder_input_ids`` is missing, or
                the version-1.0 length requirement is not met.
        """
        use_cache = use_cache if use_cache is not None else self.gpt_config.use_cache
        return_dict = return_dict if return_dict is not None else self.gpt_config.use_return_dict

        # Tolerate decoder ids carrying an extra leading dimension.
        if decoder_input_ids is not None and decoder_input_ids.dim() == 3:
            decoder_input_ids = decoder_input_ids.squeeze(0)

        # Without ESM, fall back to caller-supplied encoder states (previously this
        # path left ``graph_emb`` unbound and crashed).
        graph_emb = encoder_hidden_states
        if self.config.esm:
            if encoder_input_ids is None:
                raise ValueError("encoder_input_ids must be provided when the model uses ESM (config.esm=True)")
            if self.config.prot2text_version == '1.0' and encoder_input_ids.size(1) != 1021:
                raise ValueError("For this version of the model you need to PAD/Truncate the amino acid sequence for the ESM model to 1021")

            # Always request a ModelOutput: only ``last_hidden_state`` is needed here,
            # and the tuple returned for return_dict=False has no such attribute.
            esm_emb = self.esm(input_ids=encoder_input_ids, attention_mask=attention_mask, return_dict=True).last_hidden_state
            graph_emb = self.to_embedding(esm_emb)
        else:
            attention_mask = encoder_attention_mask
        if self.config.prot2text_version == '1.0':
            # Version 1.0 decodes without a cross-attention mask.
            attention_mask = None
        if get_graph_emb:
            return graph_emb

        transformer_outputs = self.decoder(input_ids=decoder_input_ids,
                                           past_key_values=past_key_values,
                                           attention_mask=decoder_attention_mask,
                                           token_type_ids=token_type_ids,
                                           position_ids=position_ids,
                                           head_mask=head_mask,
                                           inputs_embeds=inputs_embeds,
                                           encoder_hidden_states=graph_emb,
                                           encoder_attention_mask=attention_mask,
                                           labels=labels,
                                           use_cache=use_cache,
                                           output_attentions=output_attentions,
                                           output_hidden_states=output_hidden_states,
                                           return_dict=return_dict,
                                           )

        return transformer_outputs

    @torch.no_grad()
    def generate_protein_description(self,
                                     protein_sequence=None,
                                     tokenizer=None,
                                     device='cpu'
                                     ):
        """Generate a text description for a single amino-acid sequence.

        The sequence is tokenized with the ESM tokenizer named by
        ``config.esm_model_name`` (padded/truncated to 1021 tokens), encoded,
        and decoded by ``self.decoder.generate`` starting from
        ``tokenizer.bos_token_id``. Note: the model is moved to ``device``.

        Args:
            protein_sequence: amino-acid sequence string.
            tokenizer: the decoder (GPT-2) tokenizer used for the BOS id and decoding.
            device: device to move the model and inputs to.

        Returns:
            The first decoded string with ``<|stop_token|>`` and
            ``<|graph_token|>`` removed.

        Raises:
            ValueError: on a non-ESM model, or a missing sequence/tokenizer.
        """
        if self.config.esm and not self.config.rgcn and protein_sequence is None:
            raise ValueError(
                "The model you are trying to use is based only on protein sequence, please provide an amino-acid protein_sequence"
            )
        if not self.config.esm:
            raise ValueError("generate_protein_description requires a model with an ESM encoder (config.esm=True)")
        if protein_sequence is None:
            raise ValueError("protein_sequence is required to generate a description")
        if tokenizer is None:
            raise ValueError("tokenizer (the decoder tokenizer) is required to generate a description")

        # NOTE(review): the ESM tokenizer is reloaded on every call; consider caching it.
        esmtokenizer = AutoTokenizer.from_pretrained(self.config.esm_model_name)
        seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")

        inputs = {
            'encoder_input_ids': seq['input_ids'],
            'attention_mask': seq['attention_mask'],
        }
        # Decoder prompt: one BOS token per batch row (same shape/dtype as the first ESM column).
        inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:, 0:1].clone()
        inputs['decoder_input_ids'][:, 0] = tokenizer.bos_token_id

        self.to(device)
        inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
        encoder_state = {'hidden_states': self(**inputs, get_graph_emb=True, output_attentions=True)}
        generated = tokenizer.batch_decode(
            self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state, use_cache=True),
            skip_special_tokens=True,
        )

        return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')

    @torch.no_grad()
    def generate(self,
                 inputs: Optional[torch.Tensor] = None,
                 generation_config: Optional[GenerationConfig] = None,
                 logits_processor: Optional[LogitsProcessorList] = None,
                 stopping_criteria: Optional[StoppingCriteriaList] = None,
                 prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
                 synced_gpus: Optional[bool] = None,
                 assistant_model: Optional["PreTrainedModel"] = None,
                 streamer: Optional["BaseStreamer"] = None,
                 **kwargs,
                 ):
        """Generate text with the decoder, conditioned on the encoded protein.

        All model inputs come through ``kwargs``: they are run through
        ``forward(get_graph_emb=True)`` to obtain the encoder embeddings, then
        encoder/dataset-only keys are dropped and the remaining kwargs are
        forwarded to ``self.decoder.generate``. ``kwargs['decoder_input_ids']``
        is required and used as the prompt; ``inputs`` is accepted for
        signature compatibility but ignored.

        Raises:
            KeyError: if ``decoder_input_ids`` is missing from ``kwargs``.
        """
        encoder_state = self(**kwargs, get_graph_emb=True)
        input_ids = kwargs['decoder_input_ids']

        # NOTE(review): forward() drops the cross-attention mask for
        # prot2text_version '1.0', but it is forwarded here unchanged.
        encoder_attention_mask = kwargs.get('attention_mask')
        if (encoder_attention_mask is not None
                and not self.config.cross_esm_graph and self.config.rgcn and self.config.esm):
            # Prepend one always-attended position; presumably an extra encoder
            # token in this configuration -- TODO confirm. ``device=`` is used
            # because Tensor.get_device() returns -1 on CPU and .to(-1) fails.
            prefix = torch.ones((encoder_attention_mask.size(0), 1), device=encoder_attention_mask.device)
            encoder_attention_mask = torch.cat((prefix, encoder_attention_mask), dim=1)
        kwargs['encoder_attention_mask'] = encoder_attention_mask

        # Keys consumed by the encoder or coming from the dataset; the decoder's
        # generate() must not receive them.
        drop_keys = ('edge_index', 'edge_type', 'x', 'encoder_input_ids', 'decoder_input_ids',
                     'decoder_attention_mask', 'batch', 'attention_mask', 'max_length',
                     '_num_nodes', 'node_id', 'name', 'sequence', 'distance_matrix', 'distance',
                     'coordinates', 'ptr', 'num_nodes')
        for key in drop_keys:
            kwargs.pop(key, None)

        return self.decoder.generate(input_ids=input_ids,
                                     generation_config=generation_config,
                                     logits_processor=logits_processor,
                                     stopping_criteria=stopping_criteria,
                                     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                                     synced_gpus=synced_gpus,
                                     assistant_model=assistant_model,
                                     streamer=streamer,
                                     encoder_outputs={'hidden_states': encoder_state, 'attentions': 0},
                                     **kwargs
                                     )