zhjohnchan committed on
Commit
b09211c
1 Parent(s): 0f2d77c

Upload 3 files

configuration_chexagent.py ADDED
@@ -0,0 +1,180 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The CheXagent Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from typing import Union
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.models.auto import CONFIG_MAPPING
21
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
22
+ from transformers.utils import logging
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class CheXagentVisionConfig(PretrainedConfig):
28
+ model_type = "chexagent_vision_model"
29
+
30
+ def __init__(
31
+ self,
32
+ hidden_size=1408,
33
+ intermediate_size=6144,
34
+ num_hidden_layers=39,
35
+ num_attention_heads=16,
36
+ image_size=224,
37
+ patch_size=14,
38
+ hidden_act="gelu",
39
+ layer_norm_eps=1e-6,
40
+ attention_dropout=0.0,
41
+ initializer_range=1e-10,
42
+ qkv_bias=True,
43
+ **kwargs,
44
+ ):
45
+ super().__init__(**kwargs)
46
+
47
+ self.hidden_size = hidden_size
48
+ self.intermediate_size = intermediate_size
49
+ self.num_hidden_layers = num_hidden_layers
50
+ self.num_attention_heads = num_attention_heads
51
+ self.patch_size = patch_size
52
+ self.image_size = image_size
53
+ self.initializer_range = initializer_range
54
+ self.attention_dropout = attention_dropout
55
+ self.layer_norm_eps = layer_norm_eps
56
+ self.hidden_act = hidden_act
57
+ self.qkv_bias = qkv_bias
58
+
59
+ @classmethod
60
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
61
+ cls._set_token_in_kwargs(kwargs)
62
+
63
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
64
+
65
+ if config_dict.get("model_type") == "chexagent":
66
+ config_dict = config_dict["vision_config"]
67
+
68
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
69
+ logger.warning(
70
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
71
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
72
+ )
73
+
74
+ return cls.from_dict(config_dict, **kwargs)
75
+
76
+
77
+ class CheXagentQFormerConfig(PretrainedConfig):
78
+ model_type = "chexagent_qformer"
79
+
80
+ def __init__(
81
+ self,
82
+ vocab_size=30522,
83
+ hidden_size=768,
84
+ num_hidden_layers=12,
85
+ num_attention_heads=12,
86
+ intermediate_size=3072,
87
+ hidden_act="gelu",
88
+ hidden_dropout_prob=0.1,
89
+ attention_probs_dropout_prob=0.1,
90
+ max_position_embeddings=512,
91
+ initializer_range=0.02,
92
+ layer_norm_eps=1e-12,
93
+ pad_token_id=0,
94
+ position_embedding_type="absolute",
95
+ cross_attention_frequency=2,
96
+ encoder_hidden_size=1408,
97
+ **kwargs,
98
+ ):
99
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
100
+
101
+ self.vocab_size = vocab_size
102
+ self.hidden_size = hidden_size
103
+ self.num_hidden_layers = num_hidden_layers
104
+ self.num_attention_heads = num_attention_heads
105
+ self.hidden_act = hidden_act
106
+ self.intermediate_size = intermediate_size
107
+ self.hidden_dropout_prob = hidden_dropout_prob
108
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
109
+ self.max_position_embeddings = max_position_embeddings
110
+ self.initializer_range = initializer_range
111
+ self.layer_norm_eps = layer_norm_eps
112
+ self.position_embedding_type = position_embedding_type
113
+ self.cross_attention_frequency = cross_attention_frequency
114
+ self.encoder_hidden_size = encoder_hidden_size
115
+
116
+ @classmethod
117
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
118
+ cls._set_token_in_kwargs(kwargs)
119
+
120
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
121
+
122
+ if config_dict.get("model_type") == "chexagent":
123
+ config_dict = config_dict["qformer_config"]
124
+
125
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
126
+ logger.warning(
127
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
128
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
129
+ )
130
+
131
+ return cls.from_dict(config_dict, **kwargs)
132
+
133
+
134
+ class CheXagentConfig(PretrainedConfig):
135
+ model_type = "chexagent"
136
+
137
+ def __init__(
138
+ self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=128,
139
+ num_max_images=2, **kwargs
140
+ ):
141
+ super().__init__(**kwargs)
142
+
143
+ if vision_config is None:
144
+ vision_config = {}
145
+
146
+ if qformer_config is None:
147
+ qformer_config = {}
148
+
149
+ if text_config is None:
150
+ text_config = {}
151
+
152
+ self.vision_config = CheXagentVisionConfig(**vision_config)
153
+ self.qformer_config = CheXagentQFormerConfig(**qformer_config)
154
+ text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
155
+ self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
156
+
157
+ self.tie_word_embeddings = self.text_config.tie_word_embeddings
158
+ self.is_encoder_decoder = self.text_config.is_encoder_decoder
159
+
160
+ self.num_query_tokens = num_query_tokens
161
+ self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
162
+ self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
163
+ self.initializer_factor = 1.0
164
+ self.initializer_range = 0.02
165
+ self.num_max_images = num_max_images
166
+
167
+ @classmethod
168
+ def from_vision_qformer_text_configs(
169
+ cls,
170
+ vision_config: CheXagentVisionConfig,
171
+ qformer_config: CheXagentQFormerConfig,
172
+ text_config: PretrainedConfig,
173
+ **kwargs,
174
+ ):
175
+ return cls(
176
+ vision_config=vision_config.to_dict(),
177
+ qformer_config=qformer_config.to_dict(),
178
+ text_config=text_config.to_dict(),
179
+ **kwargs,
180
+ )
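
For orientation (not part of the committed files), a minimal sketch of how the three configuration classes above compose, assuming `configuration_chexagent.py` is importable from the working directory and `transformers` is installed:

from configuration_chexagent import (
    CheXagentConfig,
    CheXagentQFormerConfig,
    CheXagentVisionConfig,
)

# Sub-configs take the defaults from the constructors above.
vision_config = CheXagentVisionConfig()    # ViT-style encoder, hidden_size=1408
qformer_config = CheXagentQFormerConfig()  # BERT-style Q-Former, hidden_size=768

# CheXagentConfig rebuilds the sub-configs from dicts and syncs
# qformer_config.encoder_hidden_size with vision_config.hidden_size.
config = CheXagentConfig(
    vision_config=vision_config.to_dict(),
    qformer_config=qformer_config.to_dict(),
    text_config=None,  # falls back to an OPT text config via CONFIG_MAPPING
    num_query_tokens=128,
    num_max_images=2,
)
assert config.qformer_config.encoder_hidden_size == config.vision_config.hidden_size

`from_vision_qformer_text_configs` wraps the same `to_dict()` round-trip for callers that already hold the three config objects.
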
modeling_chexagent.py ADDED
@@ -0,0 +1,1300 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The CheXagent Authors, The Salesforce Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from dataclasses import dataclass
18
+ from typing import Any, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from einops import rearrange
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutput,
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ BaseModelOutputWithPooling,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ )
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.models.auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
34
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
35
+ from transformers.utils import ModelOutput, logging
36
+
37
+ from .configuration_chexagent import CheXagentConfig, CheXagentQFormerConfig, CheXagentVisionConfig
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ @dataclass
43
+ class CheXagentForConditionalGenerationModelOutput(ModelOutput):
44
+ loss: Optional[Tuple[torch.FloatTensor]] = None
45
+ logits: Optional[Tuple[torch.FloatTensor]] = None
46
+ vision_outputs: Optional[torch.FloatTensor] = None
47
+ qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
48
+ language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None
49
+
50
+ def to_tuple(self) -> Tuple[Any]:
51
+ return tuple(
52
+ self[k]
53
+ if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
54
+ else getattr(self, k).to_tuple()
55
+ for k in self.keys()
56
+ )
57
+
58
+
59
+ class CheXagentVisionEmbeddings(nn.Module):
60
+ def __init__(self, config: CheXagentVisionConfig):
61
+ super().__init__()
62
+ self.config = config
63
+ self.embed_dim = config.hidden_size
64
+ self.image_size = config.image_size
65
+ self.patch_size = config.patch_size
66
+
67
+ self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
68
+
69
+ self.patch_embedding = nn.Conv2d(
70
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
71
+ )
72
+
73
+ self.num_patches = (self.image_size // self.patch_size) ** 2
74
+ self.num_positions = self.num_patches + 1
75
+
76
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
77
+
78
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
79
+ batch_size = pixel_values.shape[0]
80
+ target_dtype = self.patch_embedding.weight.dtype
81
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
82
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
83
+
84
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
85
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
86
+ embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
87
+ return embeddings
88
+
89
+
90
+ class CheXagentAttention(nn.Module):
91
+ def __init__(self, config):
92
+ super().__init__()
93
+ self.config = config
94
+ self.embed_dim = config.hidden_size
95
+ self.num_heads = config.num_attention_heads
96
+ self.head_dim = self.embed_dim // self.num_heads
97
+ if self.head_dim * self.num_heads != self.embed_dim:
98
+ raise ValueError(
99
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
100
+ f" {self.num_heads})."
101
+ )
102
+ self.scale = self.head_dim ** -0.5
103
+ self.dropout = nn.Dropout(config.attention_dropout)
104
+
105
+ # small tweak compared to CLIP: the fused qkv projection is created without bias (added manually below when qkv_bias is set)
106
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
107
+
108
+ if config.qkv_bias:
109
+ q_bias = nn.Parameter(torch.zeros(self.embed_dim))
110
+ v_bias = nn.Parameter(torch.zeros(self.embed_dim))
111
+ else:
112
+ q_bias = None
113
+ v_bias = None
114
+
115
+ if q_bias is not None:
116
+ qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
117
+ self.qkv.bias = nn.Parameter(qkv_bias)
118
+
119
+ self.projection = nn.Linear(self.embed_dim, self.embed_dim)
120
+
121
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
122
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
123
+
124
+ def forward(
125
+ self,
126
+ hidden_states: torch.Tensor,
127
+ head_mask: Optional[torch.Tensor] = None,
128
+ output_attentions: Optional[bool] = False,
129
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
130
+ bsz, tgt_len, embed_dim = hidden_states.size()
131
+
132
+ mixed_qkv = self.qkv(hidden_states)
133
+
134
+ mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
135
+ 2, 0, 3, 1, 4
136
+ )
137
+ query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
138
+
139
+ # Take the dot product between "query" and "key" to get the raw attention scores.
140
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
141
+
142
+ attention_scores = attention_scores * self.scale
143
+
144
+ # Normalize the attention scores to probabilities.
145
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
146
+
147
+ # This is actually dropping out entire tokens to attend to, which might
148
+ # seem a bit unusual, but is taken from the original Transformer paper.
149
+ attention_probs = self.dropout(attention_probs)
150
+
151
+ # Mask heads if we want to
152
+ if head_mask is not None:
153
+ attention_probs = attention_probs * head_mask
154
+
155
+ context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
156
+
157
+ new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
158
+ context_layer = context_layer.reshape(new_context_layer_shape)
159
+
160
+ output = self.projection(context_layer)
161
+
162
+ outputs = (output, attention_probs) if output_attentions else (output, None)
163
+
164
+ return outputs
165
+
166
+
167
+ class CheXagentMLP(nn.Module):
168
+ def __init__(self, config):
169
+ super().__init__()
170
+ self.config = config
171
+ self.activation_fn = ACT2FN[config.hidden_act]
172
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
173
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
174
+
175
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
176
+ hidden_states = self.fc1(hidden_states)
177
+ hidden_states = self.activation_fn(hidden_states)
178
+ hidden_states = self.fc2(hidden_states)
179
+ return hidden_states
180
+
181
+
182
+ class CheXagentEncoderLayer(nn.Module):
183
+ def __init__(self, config: CheXagentConfig):
184
+ super().__init__()
185
+ self.embed_dim = config.hidden_size
186
+ self.self_attn = CheXagentAttention(config)
187
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
188
+ self.mlp = CheXagentMLP(config)
189
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
190
+
191
+ def forward(
192
+ self,
193
+ hidden_states: torch.Tensor,
194
+ attention_mask: torch.Tensor,
195
+ output_attentions: Optional[bool] = False,
196
+ ) -> Tuple[torch.FloatTensor]:
197
+ residual = hidden_states
198
+ hidden_states = self.layer_norm1(hidden_states)
199
+ hidden_states, attn_weights = self.self_attn(
200
+ hidden_states=hidden_states,
201
+ head_mask=attention_mask,
202
+ output_attentions=output_attentions,
203
+ )
204
+ hidden_states = hidden_states + residual
205
+ residual = hidden_states
206
+ hidden_states = self.layer_norm2(hidden_states)
207
+ hidden_states = self.mlp(hidden_states)
208
+
209
+ hidden_states = hidden_states + residual
210
+
211
+ outputs = (hidden_states,)
212
+
213
+ if output_attentions:
214
+ outputs += (attn_weights,)
215
+
216
+ return outputs
217
+
218
+
219
+ class CheXagentPreTrainedModel(PreTrainedModel):
220
+ config_class = CheXagentConfig
221
+ base_model_prefix = "chexagent"
222
+ supports_gradient_checkpointing = True
223
+ _no_split_modules = [
224
+ "CheXagentQFormerEmbeddings",
225
+ "CheXagentAttention",
226
+ "CheXagentQFormerMultiHeadAttention",
227
+ "CheXagentQFormerSelfOutput",
228
+ ]
229
+ _keep_in_fp32_modules = []
230
+
231
+ def _init_weights(self, module):
232
+ """Initialize the weights"""
233
+ factor = self.config.initializer_range
234
+ if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
235
+ module.weight.data.normal_(mean=0.0, std=factor)
236
+ if hasattr(module, "bias") and module.bias is not None:
237
+ module.bias.data.zero_()
238
+
239
+ if isinstance(module, CheXagentVisionEmbeddings):
240
+ if hasattr(self.config, "vision_config"):
241
+ factor = self.config.vision_config.initializer_range
242
+ nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
243
+ nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
244
+
245
+ elif isinstance(module, nn.LayerNorm):
246
+ module.bias.data.zero_()
247
+ module.weight.data.fill_(1.0)
248
+ elif isinstance(module, nn.Linear) and module.bias is not None:
249
+ module.bias.data.zero_()
250
+
251
+
252
+ class CheXagentEncoder(nn.Module):
253
+ def __init__(self, config: CheXagentConfig):
254
+ super().__init__()
255
+ self.config = config
256
+ self.layers = nn.ModuleList([CheXagentEncoderLayer(config) for _ in range(config.num_hidden_layers)])
257
+ self.gradient_checkpointing = False
258
+
259
+ def forward(
260
+ self,
261
+ inputs_embeds,
262
+ attention_mask: Optional[torch.Tensor] = None,
263
+ output_attentions: Optional[bool] = None,
264
+ output_hidden_states: Optional[bool] = None,
265
+ return_dict: Optional[bool] = None,
266
+ ) -> Union[Tuple, BaseModelOutput]:
267
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
268
+ output_hidden_states = (
269
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
270
+ )
271
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
272
+
273
+ encoder_states = () if output_hidden_states else None
274
+ all_attentions = () if output_attentions else None
275
+ hidden_states = inputs_embeds
276
+ for idx, encoder_layer in enumerate(self.layers):
277
+ if output_hidden_states:
278
+ encoder_states = encoder_states + (hidden_states,)
279
+ if self.gradient_checkpointing and self.training:
280
+ layer_outputs = self._gradient_checkpointing_func(
281
+ encoder_layer.__call__,
282
+ hidden_states,
283
+ attention_mask,
284
+ output_attentions,
285
+ )
286
+ else:
287
+ layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
288
+
289
+ hidden_states = layer_outputs[0]
290
+
291
+ if output_attentions:
292
+ all_attentions = all_attentions + (layer_outputs[1],)
293
+
294
+ if output_hidden_states:
295
+ encoder_states = encoder_states + (hidden_states,)
296
+
297
+ if not return_dict:
298
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
299
+ return BaseModelOutput(
300
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
301
+ )
302
+
303
+
304
+ class CheXagentVisionModel(CheXagentPreTrainedModel):
305
+ main_input_name = "pixel_values"
306
+ config_class = CheXagentVisionConfig
307
+
308
+ def __init__(self, config: CheXagentVisionConfig):
309
+ super().__init__(config)
310
+ self.config = config
311
+ embed_dim = config.hidden_size
312
+
313
+ self.embeddings = CheXagentVisionEmbeddings(config)
314
+ self.encoder = CheXagentEncoder(config)
315
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
316
+
317
+ self.post_init()
318
+
319
+ def forward(
320
+ self,
321
+ pixel_values: Optional[torch.FloatTensor] = None,
322
+ output_attentions: Optional[bool] = None,
323
+ output_hidden_states: Optional[bool] = None,
324
+ return_dict: Optional[bool] = None,
325
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
326
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
327
+ output_hidden_states = (
328
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
329
+ )
330
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
331
+
332
+ if pixel_values is None:
333
+ raise ValueError("You have to specify pixel_values")
334
+ hidden_states = self.embeddings(pixel_values)
335
+
336
+ encoder_outputs = self.encoder(
337
+ inputs_embeds=hidden_states,
338
+ output_attentions=output_attentions,
339
+ output_hidden_states=output_hidden_states,
340
+ return_dict=return_dict,
341
+ )
342
+
343
+ last_hidden_state = encoder_outputs[0]
344
+ last_hidden_state = self.post_layernorm(last_hidden_state)
345
+
346
+ pooled_output = last_hidden_state[:, 0, :]
347
+ pooled_output = self.post_layernorm(pooled_output)
348
+
349
+ if not return_dict:
350
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
351
+
352
+ return BaseModelOutputWithPooling(
353
+ last_hidden_state=last_hidden_state,
354
+ pooler_output=pooled_output,
355
+ hidden_states=encoder_outputs.hidden_states,
356
+ attentions=encoder_outputs.attentions,
357
+ )
358
+
359
+ def get_input_embeddings(self):
360
+ return self.embeddings
361
+
362
+
363
+ class CheXagentQFormerMultiHeadAttention(nn.Module):
364
+ def __init__(self, config, is_cross_attention=False):
365
+ super().__init__()
366
+ self.config = config
367
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
368
+ raise ValueError(
369
+ "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
370
+ % (config.hidden_size, config.num_attention_heads)
371
+ )
372
+
373
+ self.num_attention_heads = config.num_attention_heads
374
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
375
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
376
+
377
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
378
+ if is_cross_attention:
379
+ self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
380
+ self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
381
+ else:
382
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
383
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
384
+
385
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
386
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
387
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
388
+ self.max_position_embeddings = config.max_position_embeddings
389
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
390
+ self.save_attention = False
391
+
392
+ def save_attn_gradients(self, attn_gradients):
393
+ self.attn_gradients = attn_gradients
394
+
395
+ def get_attn_gradients(self):
396
+ return self.attn_gradients
397
+
398
+ def save_attention_map(self, attention_map):
399
+ self.attention_map = attention_map
400
+
401
+ def get_attention_map(self):
402
+ return self.attention_map
403
+
404
+ def transpose_for_scores(self, x):
405
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
406
+ x = x.view(*new_x_shape)
407
+ return x.permute(0, 2, 1, 3)
408
+
409
+ def forward(
410
+ self,
411
+ hidden_states,
412
+ attention_mask=None,
413
+ head_mask=None,
414
+ encoder_hidden_states=None,
415
+ encoder_attention_mask=None,
416
+ past_key_value=None,
417
+ output_attentions=False,
418
+ ):
419
+ # If this is instantiated as a cross-attention module, the keys
420
+ # and values come from an encoder; the attention mask needs to be
421
+ # such that the encoder's padding tokens are not attended to.
422
+ is_cross_attention = encoder_hidden_states is not None
423
+
424
+ if is_cross_attention:
425
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
426
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
427
+ attention_mask = encoder_attention_mask
428
+ elif past_key_value is not None:
429
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
430
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
431
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
432
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
433
+ else:
434
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
435
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
436
+
437
+ mixed_query_layer = self.query(hidden_states)
438
+
439
+ query_layer = self.transpose_for_scores(mixed_query_layer)
440
+
441
+ past_key_value = (key_layer, value_layer)
442
+
443
+ # Take the dot product between "query" and "key" to get the raw attention scores.
444
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
445
+
446
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
447
+ seq_length = hidden_states.size()[1]
448
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
449
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
450
+ distance = position_ids_l - position_ids_r
451
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
452
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
453
+
454
+ if self.position_embedding_type == "relative_key":
455
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
456
+ attention_scores = attention_scores + relative_position_scores
457
+ elif self.position_embedding_type == "relative_key_query":
458
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
459
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
460
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
461
+
462
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
463
+
464
+ if attention_mask is not None:
465
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
466
+ attention_scores = attention_scores + attention_mask
467
+
468
+ # Normalize the attention scores to probabilities.
469
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
470
+
471
+ if is_cross_attention and self.save_attention:
472
+ self.save_attention_map(attention_probs)
473
+ attention_probs.register_hook(self.save_attn_gradients)
474
+
475
+ # This is actually dropping out entire tokens to attend to, which might
476
+ # seem a bit unusual, but is taken from the original Transformer paper.
477
+ attention_probs_dropped = self.dropout(attention_probs)
478
+
479
+ # Mask heads if we want to
480
+ if head_mask is not None:
481
+ attention_probs_dropped = attention_probs_dropped * head_mask
482
+
483
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
484
+
485
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
486
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
487
+ context_layer = context_layer.view(*new_context_layer_shape)
488
+
489
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
490
+
491
+ outputs = outputs + (past_key_value,)
492
+ return outputs
493
+
494
+
495
+ class CheXagentQFormerSelfOutput(nn.Module):
496
+ def __init__(self, config):
497
+ super().__init__()
498
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
499
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
500
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
501
+
502
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
503
+ hidden_states = self.dense(hidden_states)
504
+ hidden_states = self.dropout(hidden_states)
505
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
506
+ return hidden_states
507
+
508
+
509
+ class CheXagentQFormerAttention(nn.Module):
510
+ def __init__(self, config, is_cross_attention=False):
511
+ super().__init__()
512
+ self.attention = CheXagentQFormerMultiHeadAttention(config, is_cross_attention)
513
+ self.output = CheXagentQFormerSelfOutput(config)
514
+ self.pruned_heads = set()
515
+
516
+ def prune_heads(self, heads):
517
+ if len(heads) == 0:
518
+ return
519
+ heads, index = find_pruneable_heads_and_indices(
520
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
521
+ )
522
+
523
+ # Prune linear layers
524
+ self.attention.query = prune_linear_layer(self.attention.query, index)
525
+ self.attention.key = prune_linear_layer(self.attention.key, index)
526
+ self.attention.value = prune_linear_layer(self.attention.value, index)
527
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
528
+
529
+ # Update hyper params and store pruned heads
530
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
531
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
532
+ self.pruned_heads = self.pruned_heads.union(heads)
533
+
534
+ def forward(
535
+ self,
536
+ hidden_states: torch.Tensor,
537
+ attention_mask: Optional[torch.FloatTensor] = None,
538
+ head_mask: Optional[torch.FloatTensor] = None,
539
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
540
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
541
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
542
+ output_attentions: Optional[bool] = False,
543
+ ) -> Tuple[torch.Tensor]:
544
+ self_outputs = self.attention(
545
+ hidden_states,
546
+ attention_mask,
547
+ head_mask,
548
+ encoder_hidden_states,
549
+ encoder_attention_mask,
550
+ past_key_value,
551
+ output_attentions,
552
+ )
553
+ attention_output = self.output(self_outputs[0], hidden_states)
554
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
555
+ return outputs
556
+
557
+
558
+ class CheXagentQFormerIntermediate(nn.Module):
559
+ def __init__(self, config):
560
+ super().__init__()
561
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
562
+ if isinstance(config.hidden_act, str):
563
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
564
+ else:
565
+ self.intermediate_act_fn = config.hidden_act
566
+
567
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
568
+ hidden_states = self.dense(hidden_states)
569
+ hidden_states = self.intermediate_act_fn(hidden_states)
570
+ return hidden_states
571
+
572
+
573
+ class CheXagentQFormerOutput(nn.Module):
574
+ def __init__(self, config):
575
+ super().__init__()
576
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
577
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
578
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
579
+
580
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
581
+ hidden_states = self.dense(hidden_states)
582
+ hidden_states = self.dropout(hidden_states)
583
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
584
+ return hidden_states
585
+
586
+
587
+ class CheXagentQFormerLayer(nn.Module):
588
+ def __init__(self, config, layer_idx):
589
+ super().__init__()
590
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
591
+ self.seq_len_dim = 1
592
+ self.attention = CheXagentQFormerAttention(config)
593
+
594
+ self.layer_idx = layer_idx
595
+
596
+ if layer_idx % config.cross_attention_frequency == 0:
597
+ self.crossattention = CheXagentQFormerAttention(config, is_cross_attention=True)
598
+ self.has_cross_attention = True
599
+ else:
600
+ self.has_cross_attention = False
601
+
602
+ self.intermediate_query = CheXagentQFormerIntermediate(config)
603
+ self.output_query = CheXagentQFormerOutput(config)
604
+
605
+ def forward(
606
+ self,
607
+ hidden_states,
608
+ attention_mask=None,
609
+ head_mask=None,
610
+ encoder_hidden_states=None,
611
+ encoder_attention_mask=None,
612
+ past_key_value=None,
613
+ output_attentions=False,
614
+ query_length=0,
615
+ ):
616
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
617
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
618
+ self_attention_outputs = self.attention(
619
+ hidden_states,
620
+ attention_mask,
621
+ head_mask,
622
+ output_attentions=output_attentions,
623
+ past_key_value=self_attn_past_key_value,
624
+ )
625
+ attention_output = self_attention_outputs[0]
626
+ outputs = self_attention_outputs[1:-1]
627
+
628
+ present_key_value = self_attention_outputs[-1]
629
+
630
+ if query_length > 0:
631
+ query_attention_output = attention_output[:, :query_length, :]
632
+
633
+ if self.has_cross_attention:
634
+ if encoder_hidden_states is None:
635
+ raise ValueError("encoder_hidden_states must be given for cross-attention layers")
636
+ cross_attention_outputs = self.crossattention(
637
+ query_attention_output,
638
+ attention_mask,
639
+ head_mask,
640
+ encoder_hidden_states,
641
+ encoder_attention_mask,
642
+ output_attentions=output_attentions,
643
+ )
644
+ query_attention_output = cross_attention_outputs[0]
645
+ # add cross attentions if we output attention weights
646
+ outputs = outputs + cross_attention_outputs[1:-1]
647
+
648
+ layer_output = apply_chunking_to_forward(
649
+ self.feed_forward_chunk_query,
650
+ self.chunk_size_feed_forward,
651
+ self.seq_len_dim,
652
+ query_attention_output,
653
+ )
654
+
655
+ if attention_output.shape[1] > query_length:
656
+ layer_output_text = apply_chunking_to_forward(
657
+ self.feed_forward_chunk,
658
+ self.chunk_size_feed_forward,
659
+ self.seq_len_dim,
660
+ attention_output[:, query_length:, :],
661
+ )
662
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
663
+ else:
664
+ layer_output = apply_chunking_to_forward(
665
+ self.feed_forward_chunk,
666
+ self.chunk_size_feed_forward,
667
+ self.seq_len_dim,
668
+ attention_output,
669
+ )
670
+ outputs = (layer_output,) + outputs
671
+
672
+ outputs = outputs + (present_key_value,)
673
+
674
+ return outputs
675
+
676
+ def feed_forward_chunk(self, attention_output):
677
+ intermediate_output = self.intermediate(attention_output)
678
+ layer_output = self.output(intermediate_output, attention_output)
679
+ return layer_output
680
+
681
+ def feed_forward_chunk_query(self, attention_output):
682
+ intermediate_output = self.intermediate_query(attention_output)
683
+ layer_output = self.output_query(intermediate_output, attention_output)
684
+ return layer_output
685
+
686
+
687
+ class CheXagentQFormerEncoder(nn.Module):
688
+ def __init__(self, config):
689
+ super().__init__()
690
+ self.config = config
691
+ self.layer = nn.ModuleList(
692
+ [CheXagentQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
693
+ )
694
+ self.gradient_checkpointing = False
695
+
696
+ def forward(
697
+ self,
698
+ hidden_states,
699
+ attention_mask=None,
700
+ head_mask=None,
701
+ encoder_hidden_states=None,
702
+ encoder_attention_mask=None,
703
+ past_key_values=None,
704
+ use_cache=None,
705
+ output_attentions=False,
706
+ output_hidden_states=False,
707
+ return_dict=True,
708
+ query_length=0,
709
+ ):
710
+ all_hidden_states = () if output_hidden_states else None
711
+ all_self_attentions = () if output_attentions else None
712
+ all_cross_attentions = () if output_attentions else None
713
+
714
+ next_decoder_cache = () if use_cache else None
715
+
716
+ for i in range(self.config.num_hidden_layers):
717
+ layer_module = self.layer[i]
718
+ if output_hidden_states:
719
+ all_hidden_states = all_hidden_states + (hidden_states,)
720
+
721
+ layer_head_mask = head_mask[i] if head_mask is not None else None
722
+ past_key_value = past_key_values[i] if past_key_values is not None else None
723
+
724
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
725
+ if use_cache:
726
+ logger.warning(
727
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
728
+ )
729
+ use_cache = False
730
+ layer_outputs = self._gradient_checkpointing_func(
731
+ layer_module.__call__,
732
+ hidden_states,
733
+ attention_mask,
734
+ layer_head_mask,
735
+ encoder_hidden_states,
736
+ encoder_attention_mask,
737
+ )
738
+ else:
739
+ layer_outputs = layer_module(
740
+ hidden_states,
741
+ attention_mask,
742
+ layer_head_mask,
743
+ encoder_hidden_states,
744
+ encoder_attention_mask,
745
+ past_key_value,
746
+ output_attentions,
747
+ query_length,
748
+ )
749
+
750
+ hidden_states = layer_outputs[0]
751
+ if use_cache:
752
+ next_decoder_cache += (layer_outputs[-1],)
753
+ if output_attentions:
754
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
755
+ if layer_module.has_cross_attention:
756
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
757
+
758
+ if output_hidden_states:
759
+ all_hidden_states = all_hidden_states + (hidden_states,)
760
+
761
+ if not return_dict:
762
+ return tuple(
763
+ v
764
+ for v in [
765
+ hidden_states,
766
+ next_decoder_cache,
767
+ all_hidden_states,
768
+ all_self_attentions,
769
+ all_cross_attentions,
770
+ ]
771
+ if v is not None
772
+ )
773
+ return BaseModelOutputWithPastAndCrossAttentions(
774
+ last_hidden_state=hidden_states,
775
+ past_key_values=next_decoder_cache,
776
+ hidden_states=all_hidden_states,
777
+ attentions=all_self_attentions,
778
+ cross_attentions=all_cross_attentions,
779
+ )
780
+
781
+
782
+ class CheXagentQFormerEmbeddings(nn.Module):
783
+ def __init__(self, config):
784
+ super().__init__()
785
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
786
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
787
+
788
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
789
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
790
+
791
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
792
+ self.register_buffer(
793
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
794
+ )
795
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
796
+
797
+ self.config = config
798
+
799
+ def forward(
800
+ self,
801
+ input_ids=None,
802
+ position_ids=None,
803
+ query_embeds=None,
804
+ past_key_values_length=0,
805
+ ):
806
+ if input_ids is not None:
807
+ seq_length = input_ids.size()[1]
808
+ else:
809
+ seq_length = 0
810
+
811
+ if position_ids is None:
812
+ position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length].clone()
813
+
814
+ if input_ids is not None:
815
+ embeddings = self.word_embeddings(input_ids)
816
+ if self.position_embedding_type == "absolute":
817
+ position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
818
+ embeddings = embeddings + position_embeddings
819
+
820
+ if query_embeds is not None:
821
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
822
+ else:
823
+ embeddings = query_embeds
824
+
825
+ embeddings = embeddings.to(self.layernorm.weight.dtype)
826
+ embeddings = self.layernorm(embeddings)
827
+ embeddings = self.dropout(embeddings)
828
+ return embeddings
829
+
830
+
831
+ class CheXagentQFormerEncoder(nn.Module):
832
+ def __init__(self, config):
833
+ super().__init__()
834
+ self.config = config
835
+ self.layer = nn.ModuleList(
836
+ [CheXagentQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
837
+ )
838
+ self.gradient_checkpointing = False
839
+
840
+ def forward(
841
+ self,
842
+ hidden_states,
843
+ attention_mask=None,
844
+ head_mask=None,
845
+ encoder_hidden_states=None,
846
+ encoder_attention_mask=None,
847
+ past_key_values=None,
848
+ use_cache=None,
849
+ output_attentions=False,
850
+ output_hidden_states=False,
851
+ return_dict=True,
852
+ query_length=0,
853
+ ):
854
+ all_hidden_states = () if output_hidden_states else None
855
+ all_self_attentions = () if output_attentions else None
856
+ all_cross_attentions = () if output_attentions else None
857
+
858
+ next_decoder_cache = () if use_cache else None
859
+
860
+ for i in range(self.config.num_hidden_layers):
861
+ layer_module = self.layer[i]
862
+ if output_hidden_states:
863
+ all_hidden_states = all_hidden_states + (hidden_states,)
864
+
865
+ layer_head_mask = head_mask[i] if head_mask is not None else None
866
+ past_key_value = past_key_values[i] if past_key_values is not None else None
867
+
868
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
869
+ if use_cache:
870
+ logger.warning(
871
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
872
+ )
873
+ use_cache = False
874
+ layer_outputs = self._gradient_checkpointing_func(
875
+ layer_module.__call__,
876
+ hidden_states,
877
+ attention_mask,
878
+ layer_head_mask,
879
+ encoder_hidden_states,
880
+ encoder_attention_mask,
881
+ )
882
+ else:
883
+ layer_outputs = layer_module(
884
+ hidden_states,
885
+ attention_mask,
886
+ layer_head_mask,
887
+ encoder_hidden_states,
888
+ encoder_attention_mask,
889
+ past_key_value,
890
+ output_attentions,
891
+ query_length,
892
+ )
893
+
894
+ hidden_states = layer_outputs[0]
895
+ if use_cache:
896
+ next_decoder_cache += (layer_outputs[-1],)
897
+ if output_attentions:
898
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
899
+ if layer_module.has_cross_attention:
900
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
901
+
902
+ if output_hidden_states:
903
+ all_hidden_states = all_hidden_states + (hidden_states,)
904
+
905
+ if not return_dict:
906
+ return tuple(
907
+ v
908
+ for v in [
909
+ hidden_states,
910
+ next_decoder_cache,
911
+ all_hidden_states,
912
+ all_self_attentions,
913
+ all_cross_attentions,
914
+ ]
915
+ if v is not None
916
+ )
917
+ return BaseModelOutputWithPastAndCrossAttentions(
918
+ last_hidden_state=hidden_states,
919
+ past_key_values=next_decoder_cache,
920
+ hidden_states=all_hidden_states,
921
+ attentions=all_self_attentions,
922
+ cross_attentions=all_cross_attentions,
923
+ )
924
+
925
+
926
+ class CheXagentQFormerModel(CheXagentPreTrainedModel):
927
+ def __init__(self, config: CheXagentQFormerConfig):
928
+ super().__init__(config)
929
+ self.config = config
930
+
931
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
932
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
933
+
934
+ self.encoder = CheXagentQFormerEncoder(config)
935
+
936
+ self.post_init()
937
+
938
+ def get_input_embeddings(self):
939
+ return self.embeddings.word_embeddings
940
+
941
+ def set_input_embeddings(self, value):
942
+ self.embeddings.word_embeddings = value
943
+
944
+ def _prune_heads(self, heads_to_prune):
945
+ for layer, heads in heads_to_prune.items():
946
+ self.encoder.layer[layer].attention.prune_heads(heads)
947
+
948
+ def get_extended_attention_mask(
949
+ self,
950
+ attention_mask: torch.Tensor,
951
+ input_shape: Tuple[int],
952
+ device: torch.device,
953
+ has_query: bool = False,
954
+ ) -> torch.Tensor:
955
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
956
+ # ourselves in which case we just need to make it broadcastable to all heads.
957
+ if attention_mask.dim() == 3:
958
+ extended_attention_mask = attention_mask[:, None, :, :]
959
+ elif attention_mask.dim() == 2:
960
+ # Provided a padding mask of dimensions [batch_size, seq_length]
961
+ # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
962
+ extended_attention_mask = attention_mask[:, None, None, :]
963
+ else:
964
+ raise ValueError(
965
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
966
+ input_shape, attention_mask.shape
967
+ )
968
+ )
969
+
970
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
971
+ # masked positions, this operation will create a tensor which is 0.0 for
972
+ # positions we want to attend and -10000.0 for masked positions.
973
+ # Since we are adding it to the raw scores before the softmax, this is
974
+ # effectively the same as removing these entirely.
975
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
976
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
977
+ return extended_attention_mask
978
+
979
+ def forward(
980
+ self,
981
+ query_embeds: torch.FloatTensor,
982
+ attention_mask: Optional[torch.FloatTensor] = None,
983
+ head_mask: Optional[torch.FloatTensor] = None,
984
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
985
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
986
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
987
+ use_cache: Optional[bool] = None,
988
+ output_attentions: Optional[bool] = None,
989
+ output_hidden_states: Optional[bool] = None,
990
+ return_dict: Optional[bool] = None,
991
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
992
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
993
+ output_hidden_states = (
994
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
995
+ )
996
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
997
+
998
+ # past_key_values_length
999
+ past_key_values_length = (
1000
+ past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
1001
+ )
1002
+
1003
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
1004
+
1005
+ embedding_output = self.layernorm(query_embeds)
1006
+ embedding_output = self.dropout(embedding_output)
1007
+
1008
+ input_shape = embedding_output.size()[:-1]
1009
+ batch_size, seq_length = input_shape
1010
+ device = embedding_output.device
1011
+
1012
+ if attention_mask is None:
1013
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
1014
+
1015
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1016
+ # ourselves in which case we just need to make it broadcastable to all heads.
1017
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
1018
+
1019
+ # If a 2D or 3D attention mask is provided for the cross-attention
1020
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1021
+ if encoder_hidden_states is not None:
1022
+ if type(encoder_hidden_states) == list:
1023
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
1024
+ else:
1025
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1026
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1027
+
1028
+ if type(encoder_attention_mask) == list:
1029
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
1030
+ elif encoder_attention_mask is None:
1031
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1032
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1033
+ else:
1034
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1035
+ else:
1036
+ encoder_extended_attention_mask = None
1037
+
1038
+ # Prepare head mask if needed
1039
+ # 1.0 in head_mask indicate we keep the head
1040
+ # attention_probs has shape bsz x n_heads x N x N
1041
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1042
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1043
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1044
+
1045
+ encoder_outputs = self.encoder(
1046
+ embedding_output,
1047
+ attention_mask=extended_attention_mask,
1048
+ head_mask=head_mask,
1049
+ encoder_hidden_states=encoder_hidden_states,
1050
+ encoder_attention_mask=encoder_extended_attention_mask,
1051
+ past_key_values=past_key_values,
1052
+ use_cache=use_cache,
1053
+ output_attentions=output_attentions,
1054
+ output_hidden_states=output_hidden_states,
1055
+ return_dict=return_dict,
1056
+ query_length=query_length,
1057
+ )
1058
+ sequence_output = encoder_outputs[0]
1059
+ pooled_output = sequence_output[:, 0, :]
1060
+
1061
+ if not return_dict:
1062
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1063
+
1064
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1065
+ last_hidden_state=sequence_output,
1066
+ pooler_output=pooled_output,
1067
+ past_key_values=encoder_outputs.past_key_values,
1068
+ hidden_states=encoder_outputs.hidden_states,
1069
+ attentions=encoder_outputs.attentions,
1070
+ cross_attentions=encoder_outputs.cross_attentions,
1071
+ )
1072
+
1073
+
1074
+ class CheXagentForConditionalGeneration(CheXagentPreTrainedModel):
1075
+ config_class = CheXagentConfig
1076
+ main_input_name = "pixel_values"
1077
+
1078
+ def __init__(self, config: CheXagentConfig):
1079
+ super().__init__(config)
1080
+
1081
+ self.vision_model = CheXagentVisionModel(config.vision_config)
1082
+
1083
+ self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
1084
+ self.qformer = CheXagentQFormerModel(config.qformer_config)
1085
+
1086
+ self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
1087
+ if config.use_decoder_only_language_model:
1088
+ language_model = AutoModelForCausalLM.from_config(config.text_config)
1089
+ else:
1090
+ language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
1091
+
1092
+ # Update _tied_weights_keys using the base model used.
1093
+ if language_model._tied_weights_keys is not None:
1094
+ self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
1095
+
1096
+ self.language_model = language_model
1097
+
1098
+ # Initialize weights and apply final processing
1099
+ self.post_init()
1100
+
1101
+ def get_input_embeddings(self):
1102
+ return self.language_model.get_input_embeddings()
1103
+
1104
+ def set_input_embeddings(self, value):
1105
+ self.language_model.set_input_embeddings(value)
1106
+
1107
+ def set_output_embeddings(self, new_embeddings):
1108
+ self.language_model.set_output_embeddings(new_embeddings)
1109
+
1110
+ def get_output_embeddings(self) -> nn.Module:
1111
+ return self.language_model.get_output_embeddings()
1112
+
1113
+ def get_encoder(self):
1114
+ return self.language_model.get_encoder()
1115
+
1116
+ def get_decoder(self):
1117
+ return self.language_model.get_decoder()
1118
+
1119
+ def _tie_weights(self):
1120
+ if not self.config.use_decoder_only_language_model:
1121
+ self.language_model.encoder.embed_tokens = self.language_model.shared
1122
+ self.language_model.decoder.embed_tokens = self.language_model.shared
1123
+
1124
+ def _preprocess_accelerate(self):
1125
+ hf_device_map = self.hf_device_map
1126
+ if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
1127
+ # warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`.
1128
+ logger.warning(
1129
+ "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
1130
+ " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
1131
+ " Please pass a `device_map` that contains `language_model` to remove this warning."
1132
+ " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
1133
+ " more details on creating a `device_map` for large models.",
1134
+ )
1135
+ if hasattr(self.language_model, "_hf_hook"):
1136
+ self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
1137
+
1138
+ def forward(
1139
+ self,
1140
+ pixel_values: torch.FloatTensor = None,
1141
+ input_ids: torch.FloatTensor = None,
1142
+ attention_mask: Optional[torch.LongTensor] = None,
1143
+ output_attentions: Optional[bool] = None,
1144
+ output_hidden_states: Optional[bool] = None,
1145
+ labels: Optional[torch.LongTensor] = None,
1146
+ return_dict: Optional[bool] = None,
1147
+ ) -> Union[Tuple, CheXagentForConditionalGenerationModelOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         vision_outputs, query_outputs = None, None
+         if pixel_values is not None:
+             # step 1: forward the images through the vision encoder,
+             # to get image embeddings of shape (batch_size, num_images, seq_len, hidden_size);
+             # all-zero images are treated as padding and are not encoded
+             image_mask = pixel_values.sum(dim=(2, 3, 4)) != 0
+             vision_outputs = self.vision_model(
+                 pixel_values=pixel_values[image_mask],
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+             tmp = vision_outputs[0]
+             image_embeds = tmp.new_zeros((*image_mask.shape, *tmp.shape[1:]))
+             image_embeds[image_mask] = tmp
+
+             # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+             image_attention_mask = torch.zeros(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+             image_attention_mask[image_mask] = 1
+
+             # flatten the per-image token sequences so the QFormer attends over all images of a sample at once
+             image_embeds = rearrange(image_embeds, "b i n d -> b (i n) d")
+             image_attention_mask = rearrange(image_attention_mask, "b i n -> b (i n)")
+             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+             query_outputs = self.qformer(
+                 query_embeds=query_tokens,
+                 encoder_hidden_states=image_embeds,
+                 encoder_attention_mask=image_attention_mask,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+             query_output = query_outputs[0]
+
+             # step 3: project the visual query output into the language model's embedding space
+             input_vis = self.language_projection(query_output)
+             vis_atts = torch.ones(input_vis.size()[:-1], dtype=torch.long, device=input_vis.device)
+
+         # step 4: get the embeddings of the prompt
+         inputs_lang = self.language_model.get_input_embeddings()(input_ids)
+         lang_atts = attention_mask
+         if lang_atts is None:
+             lang_atts = torch.ones_like(input_ids)
+
+         # step 5: condition the language model on the images and/or the prompt
+         if pixel_values is not None:
+             inputs_embeds = torch.cat([input_vis, inputs_lang], dim=1)
+             attention_mask = torch.cat([vis_atts, lang_atts], dim=1)
+         else:
+             inputs_embeds = inputs_lang
+             attention_mask = lang_atts
+
+         outputs = self.language_model(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         logits = outputs.logits if return_dict else outputs[0]
+
+         loss = None
+         # we compute the loss here since we need to take into account the sequence length of the query embeds
+         if labels is not None:
+             # make the target: the positions of the projected visual tokens are ignored (-100)
+             if pixel_values is not None:
+                 empty_labels = torch.ones(vis_atts.size(), dtype=torch.long, device=input_ids.device).fill_(-100)
+                 labels = torch.cat([empty_labels, labels], dim=1)
+             labels = labels.to(logits.device)
+             logits = logits[:, -labels.size(1):, :]
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous().to(logits.device)
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss(reduction="mean")
+             loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
+
+         if not return_dict:
+             output = (logits, vision_outputs, query_outputs, outputs)
+             return ((loss,) + output) if loss is not None else output
+
+         return CheXagentForConditionalGenerationModelOutput(
+             loss=loss,
+             logits=logits,
+             vision_outputs=vision_outputs,
+             qformer_outputs=query_outputs,
+             language_model_outputs=outputs,
+         )
+
+     @torch.no_grad()
+     def generate(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         **generate_kwargs,
+     ) -> torch.LongTensor:
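+         """
+         Autoregressively generate text conditioned on the (optional) images and the prompt.
+
+         As the code below shows, the projected visual tokens are prepended to the prompt embeddings
+         and passed to the language model's `generate` through `inputs_embeds`; any additional
+         `generate_kwargs` are forwarded unchanged.
+         """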
+         if hasattr(self, "hf_device_map"):
+             # preprocess for `accelerate`
+             self._preprocess_accelerate()
+
+         batch_size = pixel_values.shape[0] if pixel_values is not None else input_ids.shape[0]
+         if pixel_values is not None:
+             # step 1: forward the images through the vision encoder
+             image_mask = pixel_values.sum(dim=(2, 3, 4)) != 0
+             vision_outputs = self.vision_model(pixel_values[image_mask], return_dict=True)
+             tmp = vision_outputs[0]
+             image_embeds = tmp.new_zeros((*image_mask.shape, *tmp.shape[1:]))
+             image_embeds[image_mask] = tmp
+
+             # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+             image_attention_mask = torch.zeros(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+             image_attention_mask[image_mask] = 1
+             image_embeds = rearrange(image_embeds, "b i n d -> b (i n) d")
+             image_attention_mask = rearrange(image_attention_mask, "b i n -> b (i n)")
+             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+             query_outputs = self.qformer(
+                 query_embeds=query_tokens,
+                 encoder_hidden_states=image_embeds,
+                 encoder_attention_mask=image_attention_mask,
+                 return_dict=True,
+             )
+             query_output = query_outputs.last_hidden_state
+
+             # step 3: project the visual query output into the language model's embedding space
+             input_vis = self.language_projection(query_output)
+             vis_atts = torch.ones(input_vis.size()[:-1], dtype=torch.long, device=input_vis.device)
+
+         # step 4: get the embeddings of the prompt (fall back to a BOS-only prompt)
+         if input_ids is None:
+             input_ids = (
+                 torch.LongTensor([[self.config.text_config.bos_token_id]])
+                 .repeat(batch_size, 1)
+                 .to(next(self.parameters()).device)
+             )
+         inputs_lang = self.language_model.get_input_embeddings()(input_ids)
+         lang_atts = attention_mask
+         if lang_atts is None:
+             lang_atts = torch.ones_like(input_ids)
+
+         # step 5: condition the language model on the images and/or the prompt
+         if pixel_values is not None:
+             inputs_embeds = torch.cat([input_vis, inputs_lang], dim=1)
+             attention_mask = torch.cat([vis_atts, lang_atts], dim=1)
+         else:
+             inputs_embeds = inputs_lang
+             attention_mask = lang_atts
+
+         outputs = self.language_model.generate(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             **generate_kwargs,
+         )
+         return outputs
processing_chexagent.py ADDED
@@ -0,0 +1,126 @@
+ # coding=utf-8
+ # Copyright 2023 The CheXagent Authors and The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Optional, Union
+
+ import torch
+ from transformers.image_utils import ImageInput
+ from transformers.processing_utils import ProcessorMixin
+ from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput
+ from transformers.tokenization_utils_base import TruncationStrategy
+ from transformers.utils import TensorType
+
+
+ class CheXagentProcessor(ProcessorMixin):
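+     r"""
+     Constructs a CheXagent processor which wraps a BLIP image processor and a tokenizer into a
+     single processor, so that images and text prompts can be prepared for the model in one call.
+     """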
+     attributes = ["image_processor", "tokenizer"]
+     image_processor_class = "BlipImageProcessor"
+     tokenizer_class = "AutoTokenizer"
+
+     def __init__(self, image_processor, tokenizer):
+         tokenizer.return_token_type_ids = False
+         super().__init__(image_processor, tokenizer)
+         self.current_processor = self.image_processor
+
+     def __call__(
+         self,
+         images: ImageInput = None,
+         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+         add_special_tokens: bool = True,
+         padding: Union[bool, str, PaddingStrategy] = False,
+         truncation: Union[bool, str, TruncationStrategy] = None,
+         max_length: Optional[int] = None,
+         stride: int = 0,
+         pad_to_multiple_of: Optional[int] = None,
+         return_attention_mask: Optional[bool] = None,
+         return_overflowing_tokens: bool = False,
+         return_special_tokens_mask: bool = False,
+         return_offsets_mapping: bool = False,
+         return_token_type_ids: bool = False,
+         return_length: bool = False,
+         verbose: bool = True,
+         return_tensors: Optional[Union[str, TensorType]] = None,
+         **kwargs,
+     ) -> BatchEncoding:
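+         """
+         Prepare images and/or text for the model; at least one of `images` and `text` must be given.
+
+         Note (inferred from the stacking code below): the image branch treats the input as a single
+         sample that may contain several images, so the processed images are stacked and a leading
+         batch dimension of 1 is added, giving `pixel_values` of shape
+         (1, num_images, channels, height, width).
+         """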
+         if images is None and text is None:
+             raise ValueError("You have to specify either images or text.")
+
+         # Text-only path: tokenize the prompt and return it directly
+         if images is None:
+             self.current_processor = self.tokenizer
+             text_encoding = self.tokenizer(
+                 text=text,
+                 add_special_tokens=add_special_tokens,
+                 padding=padding,
+                 truncation=truncation,
+                 max_length=max_length,
+                 stride=stride,
+                 pad_to_multiple_of=pad_to_multiple_of,
+                 return_attention_mask=return_attention_mask,
+                 return_overflowing_tokens=return_overflowing_tokens,
+                 return_special_tokens_mask=return_special_tokens_mask,
+                 return_offsets_mapping=return_offsets_mapping,
+                 return_token_type_ids=return_token_type_ids,
+                 return_length=return_length,
+                 verbose=verbose,
+                 return_tensors=return_tensors,
+                 **kwargs,
+             )
+             return text_encoding
+
+         # add pixel_values
+         if images is not None:
+             encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
+             encoding_image_processor["pixel_values"] = torch.stack(
+                 [torch.tensor(pixel_values) for pixel_values in encoding_image_processor["pixel_values"]]
+             ).unsqueeze(0)
+
+         if text is not None:
+             text_encoding = self.tokenizer(
+                 text=text,
+                 add_special_tokens=add_special_tokens,
+                 padding=padding,
+                 truncation=truncation,
+                 max_length=max_length,
+                 stride=stride,
+                 pad_to_multiple_of=pad_to_multiple_of,
+                 return_attention_mask=return_attention_mask,
+                 return_overflowing_tokens=return_overflowing_tokens,
+                 return_special_tokens_mask=return_special_tokens_mask,
+                 return_offsets_mapping=return_offsets_mapping,
+                 return_token_type_ids=return_token_type_ids,
+                 return_length=return_length,
+                 verbose=verbose,
+                 return_tensors=return_tensors,
+                 **kwargs,
+             )
+         else:
+             text_encoding = None
+
+         if text_encoding is not None:
+             encoding_image_processor.update(text_encoding)
+
+         return encoding_image_processor
+
+     def batch_decode(self, *args, **kwargs):
+         return self.tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args, **kwargs):
+         return self.tokenizer.decode(*args, **kwargs)
+
+     @property
+     def model_input_names(self):
+         tokenizer_input_names = self.tokenizer.model_input_names
+         image_processor_input_names = self.image_processor.model_input_names
+         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
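+
+
+ # Illustrative usage sketch (editorial addition, not part of the released API; the repo id,
+ # image URL, and prompt below are placeholders/assumptions):
+ #
+ #     import requests
+ #     from PIL import Image
+ #     from transformers import AutoModelForCausalLM, AutoProcessor
+ #
+ #     repo_id = "StanfordAIMI/CheXagent-8b"  # hypothetical/placeholder repo id
+ #     processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
+ #     model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
+ #
+ #     image = Image.open(requests.get("https://example.com/cxr.png", stream=True).raw).convert("RGB")
+ #     inputs = processor(images=[image], text="Describe the findings.", return_tensors="pt")
+ #     output_ids = model.generate(**inputs, max_new_tokens=128)
+ #     print(processor.decode(output_ids[0], skip_special_tokens=True))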