zR committed
Commit: 8ba56be
Parent(s): 7031540
fix #350

visual.py CHANGED
@@ -6,6 +6,7 @@ from transformers.activations import ACT2FN
 import math
 from torch.nn import LayerNorm
 
+
 def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True):
     if scaling_attention_score:
         query_layer = query_layer / math.sqrt(query_layer.shape[-1])
@@ -16,11 +17,12 @@ def standard_attention(query_layer, key_layer, value_layer, scaling_attention_sc
     context_layer = torch.matmul(attention_probs, value_layer)
     return context_layer
 
+
 def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_score=True):
     if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score:
         # Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None.
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_layer, key_layer, value_layer,
+            query_layer, key_layer, value_layer,
             attn_mask=None,
             dropout_p=0.,
             is_causal=False
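For readers of the hunk above: attention_fn_default dispatches to the fused torch.nn.functional.scaled_dot_product_attention on PyTorch >= 2 and otherwise falls back to standard_attention. A minimal sketch of that equivalence follows; the score/softmax body of standard_attention is not shown in this diff and is paraphrased here, so treat it as an assumption:

import math
import torch

def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True):
    # Paraphrase of the reference path in visual.py (its middle lines are not shown in this diff).
    if scaling_attention_score:
        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_probs = attention_scores.softmax(dim=-1)
    return torch.matmul(attention_probs, value_layer)

q = k = v = torch.randn(1, 4, 16, 8)  # (B, H, L, D), illustrative sizes
if int(torch.__version__.split('.')[0]) >= 2:
    fused = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0., is_causal=False
    )
    print(torch.allclose(fused, standard_attention(q, k, v), atol=1e-5))  # True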
@@ -31,10 +33,12 @@ def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_
         query_layer, key_layer, value_layer, scaling_attention_score=scaling_attention_score
     )
 
+
 class PatchEmbedding(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size)
+        self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size,
+                              stride=config.patch_size)
         self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
         self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
 
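The Conv2d in PatchEmbedding uses kernel_size == stride == patch_size, i.e. it cuts the image into non-overlapping patches and projects each patch to hidden_size. A small sketch with made-up sizes (the real values come from the vision config, not from this diff):

import torch
import torch.nn as nn

# Illustrative sizes only; the actual in_channels/hidden_size/patch_size live in the config.
in_channels, hidden_size, patch_size, image_size = 3, 64, 14, 224
proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

images = torch.randn(1, in_channels, image_size, image_size)
x = proj(images)                  # (1, 64, 16, 16): one feature column per 14x14 patch
x = x.flatten(2).transpose(1, 2)  # (1, 256, 64): one token per patch
print(x.shape)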
@@ -62,11 +66,11 @@ class Attention(nn.Module):
         qkv = self.query_key_value(x)
         qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, H, L, D
         q, k, v = qkv[0], qkv[1], qkv[2]
-
+
         out = attention_fn_default(
             q, k, v
         )
-        output = self.dense(out.transpose(1, 2).view(B, L, -1))
+        output = self.dense(out.transpose(1, 2).view(B, L, -1))
         output = self.output_dropout(output)
         return output
 
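The reshape/permute in the hunk above splits the fused QKV projection into per-head tensors; the trailing comment gives the target layout (3, B, H, L, D). A quick shape check with illustrative dimensions:

import torch

B, L, num_heads, head_dim = 2, 16, 4, 8             # illustrative, not config values
qkv = torch.randn(B, L, 3 * num_heads * head_dim)   # output of self.query_key_value(x)
qkv = qkv.reshape(B, L, 3, num_heads, -1).permute(2, 0, 3, 1, 4)  # -> (3, B, H, L, D)
q, k, v = qkv[0], qkv[1], qkv[2]
print(q.shape)  # torch.Size([2, 4, 16, 8])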
@@ -105,7 +109,9 @@ class TransformerLayer(nn.Module):
         attention_output = self.input_layernorm(self.attention(attention_input))
         hidden_states = attention_input + attention_output
         mlp_input = hidden_states
-        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
+
+        # https://github.com/THUDM/GLM-4/issues/350
+        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)).to(mlp_input.device)
         output = mlp_input + mlp_output
         return output
 
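The added .to(mlp_input.device) is the substance of fix #350: with the checkpoint sharded across GPUs (e.g. accelerate's device_map="auto"), a layer's MLP/layernorm weights can sit on a different device than the residual stream, and the residual addition then fails with a device-mismatch RuntimeError. A minimal sketch of the failure mode and the cast, assuming at least two CUDA devices (the modules here are illustrative stand-ins, not the real submodules):

import torch

if torch.cuda.device_count() >= 2:
    mlp_input = torch.randn(1, 8, device="cuda:0")        # residual stream on GPU 0
    mlp = torch.nn.Linear(8, 8).to("cuda:1")              # this layer's weights sharded to GPU 1

    mlp_output = mlp(mlp_input.to("cuda:1"))              # mimics accelerate's hooks moving inputs
    # mlp_input + mlp_output                              # RuntimeError: tensors on cuda:0 and cuda:1
    output = mlp_input + mlp_output.to(mlp_input.device)  # the fix: cast back before the residual add
    print(output.device)                                  # cuda:0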
@@ -147,7 +153,8 @@ class EVA2CLIPModel(nn.Module):
         self.patch_embedding = PatchEmbedding(vision_config)
         self.transformer = Transformer(vision_config)
         self.linear_proj = GLU(config, in_features=config.hidden_size)
-        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2, stride=2)
+        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2,
+                              stride=2)
         self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
         self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
         self.scaling_factor = vision_config.scaling_factor
@@ -158,14 +165,16 @@ class EVA2CLIPModel(nn.Module):
         x = x[:, 1:]
 
         b, s, h = x.shape
-        grid_size = int(s**0.5)
+        grid_size = int(s ** 0.5)
         x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
         x = self.conv(x)
 
         x = x.flatten(2).transpose(1, 2)
         x = self.linear_proj(x)
-        boi = self.boi.expand(x.shape[0], -1, -1)
-        eoi = self.eoi.expand(x.shape[0], -1, -1)
+
+        # https://github.com/THUDM/GLM-4/issues/350
+        boi = self.boi.expand(x.shape[0], -1, -1).to(x.device)
+        eoi = self.eoi.expand(x.shape[0], -1, -1).to(x.device)
         x = torch.cat((boi, x, eoi), dim=1)
         x = x / self.scaling_factor
         return x
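The same device alignment is applied to the learned BOI/EOI markers before they are concatenated around the image tokens. A small sketch with illustrative shapes (boi/eoi stand in for the nn.Parameter buffers):

import torch

hidden_size = 8
x = torch.randn(2, 16, hidden_size)                # projected image tokens (B, S, H)
boi = torch.zeros(1, 1, hidden_size)               # stands in for self.boi
eoi = torch.zeros(1, 1, hidden_size)               # stands in for self.eoi

boi = boi.expand(x.shape[0], -1, -1).to(x.device)  # align with x's device before concatenating
eoi = eoi.expand(x.shape[0], -1, -1).to(x.device)
x = torch.cat((boi, x, eoi), dim=1)                # (B, S + 2, H)
print(x.shape)                                     # torch.Size([2, 18, 8])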
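As a usage note rather than part of the commit: the error fixed here typically appears only when the checkpoint is loaded sharded across several GPUs. A hedged sketch of such a load; the model id is an assumption about which repository this visual.py belongs to:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "THUDM/glm-4v-9b",           # assumption: substitute the actual repo id if it differs
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,      # visual.py is custom code shipped with the checkpoint
    device_map="auto",           # shards layers across GPUs; this is the setup where the pre-fix code failed
)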