gxy committed on
Commit
f8b5be2
1 Parent(s): 52b6907

FEAT: modify language init method

Browse files
Files changed (2) hide show
  1. config.json +5 -1
  2. modeling_ziya_blip2.py +1 -7
config.json CHANGED
@@ -1,12 +1,16 @@
1
  {
2
  "architectures": [
3
- "ZiyaBLIP2ForConditionalGeneration"
4
  ],
5
  "assistant_name": "<bot>",
6
  "human_name": "<human>",
7
  "initializer_factor": 1.0,
8
  "initializer_range": 0.02,
9
  "model_type": "blip-2",
 
 
 
 
10
  "num_query_tokens": 32,
11
  "prompt_prefix": "",
12
  "qformer_config": {
 
1
  {
2
  "architectures": [
3
+ "ZiyaBlip2ForCausalLM"
4
  ],
5
  "assistant_name": "<bot>",
6
  "human_name": "<human>",
7
  "initializer_factor": 1.0,
8
  "initializer_range": 0.02,
9
  "model_type": "blip-2",
10
+ "auto_map": {
11
+ "AutoModel": "modeling_ziya_blip2.ZiyaBlip2ForCausalLM",
12
+ "AutoModelForCausalLM": "modeling_ziya_blip2.ZiyaBlip2ForCausalLM"
13
+ },
14
  "num_query_tokens": 32,
15
  "prompt_prefix": "",
16
  "qformer_config": {
modeling_ziya_blip2.py CHANGED
@@ -11,7 +11,6 @@ from transformers.models.blip_2.modeling_blip_2 import Blip2ForConditionalGenera
11
  from transformers import (
12
  Blip2PreTrainedModel,
13
  Blip2VisionModel,
14
- AutoModelForCausalLM,
15
  Blip2QFormerModel,
16
  PreTrainedTokenizer,
17
  PreTrainedModel,
@@ -21,7 +20,7 @@ from transformers import (
21
  logger = logging.get_logger(__name__)
22
 
23
 
24
- class ZiyaBLIP2ForConditionalGeneration(Blip2PreTrainedModel):
25
  config_class = Blip2Config
26
  main_input_name = "pixel_values"
27
  _keys_to_ignore_on_load_missing = [
@@ -38,11 +37,6 @@ class ZiyaBLIP2ForConditionalGeneration(Blip2PreTrainedModel):
38
 
39
  self.language_projection = nn.Linear(
40
  config.qformer_config.hidden_size, config.text_config.hidden_size)
41
- if language_model is None:
42
- if config.use_decoder_only_language_model:
43
- language_model = AutoModelForCausalLM.from_config(config.text_config)
44
- else:
45
- raise Exception("not impl")
46
  self.language_model = language_model
47
 
48
  # Initialize weights and apply final processing
 
11
  from transformers import (
12
  Blip2PreTrainedModel,
13
  Blip2VisionModel,
 
14
  Blip2QFormerModel,
15
  PreTrainedTokenizer,
16
  PreTrainedModel,
 
20
  logger = logging.get_logger(__name__)
21
 
22
 
23
+ class ZiyaBlip2ForCausalLM(Blip2PreTrainedModel):
24
  config_class = Blip2Config
25
  main_input_name = "pixel_values"
26
  _keys_to_ignore_on_load_missing = [
 
37
 
38
  self.language_projection = nn.Linear(
39
  config.qformer_config.hidden_size, config.text_config.hidden_size)
 
 
 
 
 
40
  self.language_model = language_model
41
 
42
  # Initialize weights and apply final processing