fix image size unchangeable

#32
README.md CHANGED
@@ -19,8 +19,6 @@ inference: false
 
 Read this in [English](README_en.md)
 
-**2024/08/12, 本仓库代码已更新并使用 `transforemrs>=4.44.0`, 请及时更新依赖。**
-
 GLM-4V-9B 是智谱 AI 推出的最新一代预训练模型 GLM-4 系列中的开源多模态版本。
 **GLM-4V-9B** 具备 1120 * 1120 高分辨率下的中英双语多轮对话能力,在中英文综合能力、感知推理、文字识别、图表理解等多方面多模态评测中,GLM-4V-9B 表现出超越 GPT-4-turbo-2024-04-09、Gemini
 1.0 Pro、Qwen-VL-Max 和 Claude 3 Opus 的卓越性能。
@@ -48,10 +46,7 @@ GLM-4V-9B 是一个多模态语言模型,具备视觉理解能力,其相关
 
 ## 运行模型
 
-**更多推理代码和依赖信息,请访问我们的 [github](https://github.com/THUDM/GLM-4)。**
-
-**请严格按照[依赖](https://github.com/THUDM/GLM-4/blob/main/basic_demo/requirements.txt)安装,否则无法正常运行。**
-
+更多推理代码和依赖信息,请访问我们的 [github](https://github.com/THUDM/GLM-4)
 
 ```python
 import torch
README_en.md CHANGED
@@ -1,7 +1,5 @@
 # GLM-4V-9B
 
-**2024/08/12, The repository code has been updated and now requires `transformers>=4.44.0`. Please update your dependencies accordingly.**
-
 GLM-4V-9B is an open source multimodal version of the latest generation of pre-trained models in the GLM-4 series launched by Zhipu AI.
 **GLM-4V-9B** has the ability to conduct multi-round conversations in Chinese and English at a high resolution of 1120 * 1120. In multimodal evaluations of comprehensive Chinese and English abilities, perceptual reasoning, text recognition, and chart understanding, GLM-4V-9B has shown superior performance over GPT-4-turbo-2024-04-09, Gemini
 1.0 Pro, Qwen-VL-Max, and Claude 3 Opus.
@@ -31,9 +29,7 @@ GLM-4V-9B is a multimodal language model with visual understanding capabilities.
 
 ## Quick Start
 
-**For more inference code and requirements, please visit our [github page](https://github.com/THUDM/GLM-4).**
-
-**Please strictly follow the [dependencies](https://github.com/THUDM/GLM-4/blob/main/basic_demo/requirements.txt) to install, otherwise it will not run properly**
+For more inference code and requirements, please visit our [github page](https://github.com/THUDM/GLM-4).
 
 
 ```python
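
The Quick Start code block is cut off by the diff context here. For orientation, here is a minimal inference sketch consistent with the `apply_chat_template(..., return_dict=True)` call used by the `chat()` helper restored in modeling_chatglm.py below; the repo id `THUDM/glm-4v-9b`, the image path, and the prompt are placeholders rather than the exact model-card snippet.

```python
# Hypothetical minimal sketch (not the exact model-card snippet): load GLM-4V-9B
# with trust_remote_code and run one image-grounded turn via apply_chat_template.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "THUDM/glm-4v-9b"  # assumed repo id

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
).eval()

image = Image.open("example.jpg").convert("RGB")  # placeholder image
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "image": image, "content": "Describe this image."}],
    add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True,
).to(model.device)

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=256, do_sample=True, top_p=0.8, temperature=0.8)
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```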
config.json CHANGED
@@ -50,7 +50,7 @@
   "seq_length": 8192,
   "use_cache": true,
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.44.0",
+  "transformers_version": "4.40.2",
   "tie_word_embeddings": false,
   "eos_token_id": [151329, 151336, 151338],
   "pad_token_id": 151329,
generation_config.json CHANGED
@@ -9,5 +9,5 @@
   "temperature": 0.8,
   "max_length": 8192,
   "top_p": 0.8,
-  "transformers_version": "4.44.0"
+  "transformers_version": "4.40.2"
 }
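
Both config files pin `transformers_version` back to 4.40.2, matching the restored modeling code below, which still calls `_extract_past_from_model_output(..., standardize_cache_format=...)`; that private generation helper changed in transformers 4.44, which is presumably why the 4.44.0 pin is reverted. A hedged guard one could add before loading, not part of this PR:

```python
# Hypothetical version check, not part of this PR: the restored remote code targets
# the pre-4.44 generation utilities that config.json now pins (4.40.2).
import transformers
from packaging import version

if version.parse(transformers.__version__) >= version.parse("4.44.0"):
    raise RuntimeError(
        f"transformers {transformers.__version__} found; this revision expects "
        "transformers==4.40.2 (see config.json / generation_config.json)."
    )
```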
modeling_chatglm.py CHANGED
@@ -1,13 +1,18 @@
-""" PyTorch GLM-4V model. """
+""" PyTorch ChatGLM model. """
+import json
 import math
+import copy
+import warnings
 import sys
+
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
 from torch.nn.utils import skip_init
-from typing import Optional, Tuple, Union, List, Dict, Any
+from typing import Optional, Tuple, Union, List, Callable, Dict, Any
+from copy import deepcopy
 
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -848,6 +853,11 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         batch_size, seq_length = input_ids.shape
         position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
 
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+        if not self.supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+
+
 class Embedding(torch.nn.Module):
     """Language model embeddings."""
 
@@ -1082,10 +1092,12 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             outputs: ModelOutput,
             model_kwargs: Dict[str, Any],
             is_encoder_decoder: bool = False,
+            standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
         # update past_key_values
-        cache_name, cache = self._extract_past_from_model_output(outputs)
-        model_kwargs[cache_name] = cache
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
 
         # update attention mask
         if "attention_mask" in model_kwargs:
@@ -1192,6 +1204,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         loss = None
         if labels is not None:
+            # https://github.com/THUDM/GLM-4/issues/264
            new_labels = []
            for i in range(len(input_ids)):
                input_id = input_ids[i].tolist()
@@ -1203,12 +1216,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                    (
                        labels[i, :boi_token_pos + 1],
                        torch.tensor([-100]).to(labels.device).to(labels.dtype).repeat(1600),
-                        labels[i, eoi_token_pos:])))
+                        labels[i, eoi_token_pos:])))  # 在两个token之间加入
 
            labels = torch.stack(new_labels, dim=0)
+
            lm_logits = lm_logits.to(torch.float32)
+
+            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
@@ -1246,6 +1263,210 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
            for layer_past in past
        )
 
+    def process_response(self, output, history):
+        content = ""
+        history = deepcopy(history)
+        for response in output.split("<|assistant|>"):
+            if "\n" in response:
+                metadata, content = response.split("\n", maxsplit=1)
+            else:
+                metadata, content = "", response
+            if not metadata.strip():
+                content = content.strip()
+                history.append({"role": "assistant", "metadata": metadata, "content": content})
+                content = content.replace("[[训练时间]]", "2023年")
+            else:
+                history.append({"role": "assistant", "metadata": metadata, "content": content})
+                if history[0]["role"] == "system" and "tools" in history[0]:
+                    parameters = json.loads(content)
+                    content = {"name": metadata.strip(), "parameters": parameters}
+                else:
+                    content = {"name": metadata.strip(), "content": content}
+        return content, history
+
+    @torch.inference_mode()
+    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", image=None,
+             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
+             **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        message = {"role": role, "content": query}
+        if image is not None:
+            message["image"] = image
+        history.append(message)
+        inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=True,
+                                               return_tensors="pt", return_dict=True)
+        inputs = inputs.to(self.device)
+        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"),
+                        tokenizer.convert_tokens_to_ids("<|observation|>")]
+        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
+        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
+        response = tokenizer.decode(outputs)
+        response, history = self.process_response(response, history)
+        return response, history
+
+    @torch.inference_mode()
+    def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", image=None,
+                    past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
+                    logits_processor=None, return_past_key_values=False, **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"),
+                        tokenizer.convert_tokens_to_ids("<|observation|>")]
+        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        message = {"role": role, "content": "query"}
+        if image is not None:
+            message["image"] = image
+        if past_key_values is None:
+            inputs = tokenizer.apply_chat_template(history + [message],
+                                                   add_generation_prompt=True, tokenize=True, return_tensors="pt",
+                                                   return_dict=True)
+        else:
+            inputs = tokenizer.apply_chat_template([message], add_special_tokens=False,
+                                                   add_generation_prompt=True, tokenize=True, return_tensors="pt",
+                                                   return_dict=True)
+        inputs = inputs.to(self.device)
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            if self.transformer.pre_seq_len is not None:
+                past_length -= self.transformer.pre_seq_len
+            inputs.position_ids += past_length
+            attention_mask = inputs.attention_mask
+            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
+            inputs['attention_mask'] = attention_mask
+        history.append({"role": role, "content": query})
+        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
+                                            eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
+                                            **gen_kwargs):
+            if return_past_key_values:
+                outputs, past_key_values = outputs
+            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
+            response = tokenizer.decode(outputs)
+            if response and response[-1] != "�":
+                response, new_history = self.process_response(response, history)
+                if return_past_key_values:
+                    yield response, new_history, past_key_values
+                else:
+                    yield response, new_history
+
+    @torch.inference_mode()
+    def stream_generate(
+            self,
+            input_ids,
+            generation_config: Optional[GenerationConfig] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+            return_past_key_values=False,
+            **kwargs,
+    ):
+        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)
+        model_kwargs["use_cache"] = generation_config.use_cache
+        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
+
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
+            warnings.warn(
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                UserWarning,
+            )
+        elif generation_config.max_new_tokens is not None:
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )
+
+        if input_ids_seq_length >= generation_config.max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
+            )
+
+        # 2. Set generation parameters if not already defined
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+        )
+
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+        logits_warper = self._get_logits_warper(generation_config)
+
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        scores = None
+        while True:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )
+
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_warper(input_ids, next_token_scores)
+
+            # sample
+            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            if generation_config.do_sample:
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(probs, dim=-1)
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+            unfinished_sequences = unfinished_sequences.mul(
+                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+            )
+            if return_past_key_values:
+                yield input_ids, outputs.past_key_values
+            else:
+                yield input_ids
+            # stop when each sentence is finished, or if we exceed the maximum length
+            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+                break
+
+
 class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
     def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
         super().__init__(config)
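
The restored `chat()` and `stream_chat()` helpers above bundle `apply_chat_template`, `generate`, and `process_response` into a single call. A minimal usage sketch, assuming `model` and `tokenizer` were loaded from this checkpoint with `trust_remote_code=True` as in the Quick Start snippet; the image path and prompts are placeholders.

```python
# Hypothetical usage of the helpers restored by this diff; `model` and `tokenizer`
# are assumed to be the GLM-4V-9B checkpoint loaded with trust_remote_code=True.
from PIL import Image

image = Image.open("example.jpg").convert("RGB")  # placeholder image

# chat(): appends the user turn, generates with the sampling defaults
# (top_p=0.8, temperature=0.8, max_length=8192), and parses the assistant reply.
response, history = model.chat(tokenizer, "What is in this image?", history=[], image=image)
print(response)

# stream_chat(): yields progressively longer responses as tokens are sampled.
for partial, _ in model.stream_chat(tokenizer, "Describe it in more detail.", history=history):
    print(partial, end="\r")
```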
tokenization_chatglm.py CHANGED
@@ -303,7 +303,6 @@ class ChatGLM4Tokenizer(PreTrainedTokenizer):
             max_length: Optional[int] = None,
             padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
             pad_to_multiple_of: Optional[int] = None,
-            padding_side: Optional[str] = None,
             return_attention_mask: Optional[bool] = None,
     ) -> dict:
         """
visual.py CHANGED
@@ -6,7 +6,6 @@ from transformers.activations import ACT2FN
 import math
 from torch.nn import LayerNorm
 
-
 def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True):
     if scaling_attention_score:
         query_layer = query_layer / math.sqrt(query_layer.shape[-1])
@@ -17,12 +16,11 @@ def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_score=True):
     context_layer = torch.matmul(attention_probs, value_layer)
     return context_layer
 
-
 def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_score=True):
     if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score:
         # Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None.
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_layer, key_layer, value_layer,
+            query_layer, key_layer, value_layer,
             attn_mask=None,
             dropout_p=0.,
             is_causal=False
@@ -33,12 +31,10 @@ def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_score=True):
             query_layer, key_layer, value_layer, scaling_attention_score=scaling_attention_score
         )
 
-
 class PatchEmbedding(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size,
-                              stride=config.patch_size)
+        self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size)
         self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
         self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
 
@@ -66,7 +62,7 @@ class Attention(nn.Module):
         qkv = self.query_key_value(x)
         qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, H, L, D
         q, k, v = qkv[0], qkv[1], qkv[2]
-
+
         out = attention_fn_default(
             q, k, v
         )
@@ -109,9 +105,7 @@ class TransformerLayer(nn.Module):
         attention_output = self.input_layernorm(self.attention(attention_input))
         hidden_states = attention_input + attention_output
         mlp_input = hidden_states
-
-        # https://github.com/THUDM/GLM-4/issues/350
-        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)).to(mlp_input.device)
+        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
         output = mlp_input + mlp_output
         return output
 
@@ -153,8 +147,7 @@ class EVA2CLIPModel(nn.Module):
         self.patch_embedding = PatchEmbedding(vision_config)
         self.transformer = Transformer(vision_config)
         self.linear_proj = GLU(config, in_features=config.hidden_size)
-        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2,
-                              stride=2)
+        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2, stride=2)
         self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
         self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
         self.scaling_factor = vision_config.scaling_factor
@@ -165,16 +158,14 @@ class EVA2CLIPModel(nn.Module):
         x = x[:, 1:]
 
         b, s, h = x.shape
-        grid_size = int(s ** 0.5)
+        grid_size = int(s**0.5)
         x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
         x = self.conv(x)
 
         x = x.flatten(2).transpose(1, 2)
         x = self.linear_proj(x)
-
-        # https://github.com/THUDM/GLM-4/issues/350
-        boi = self.boi.expand(x.shape[0], -1, -1).to(x.device)
-        eoi = self.eoi.expand(x.shape[0], -1, -1).to(x.device)
+        boi = self.boi.expand(x.shape[0], -1, -1)
+        eoi = self.eoi.expand(x.shape[0], -1, -1)
         x = torch.cat((boi, x, eoi), dim=1)
         x = x / self.scaling_factor
         return x
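
For reference, the image-size arithmetic the PR title points at runs through these layers: `PatchEmbedding.proj` cuts the input into `patch_size` patches, `EVA2CLIPModel.forward` folds the resulting sequence back onto a `grid_size x grid_size` grid, and the kernel-2 stride-2 conv halves that grid before projection into the language model. A worked sketch of the numbers, assuming a 1120 px input (per the README) and a vision patch size of 14; the patch size is an assumption, but the result is consistent with the 1600-token `repeat(1600)` span in modeling_chatglm.py.

```python
# Illustrative arithmetic only; image_size=1120 comes from the README and
# patch_size=14 is assumed from the shipped vision config. The result matches
# the repeat(1600) placeholder span (boi + 1600 image tokens + eoi).
image_size = 1120          # input resolution (README: 1120 * 1120)
patch_size = 14            # assumed EVA2CLIP patch size

patches_per_side = image_size // patch_size      # 80
seq_len = patches_per_side ** 2                  # 6400 patch tokens into the ViT
grid_size = int(seq_len ** 0.5)                  # 80, as in EVA2CLIPModel.forward
downsampled = grid_size // 2                     # 40, after the kernel=2 stride=2 conv
image_tokens = downsampled ** 2                  # 1600 tokens handed to the LLM

assert image_tokens == 1600
```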