amitha committed
Commit 650d074
1 Parent(s): 94c499a

Upload LlavaBaichuanForCausalLM

Files changed (2):
  1. config.json +3 -2
  2. llava_baichuan.py +145 -0
config.json CHANGED
@@ -5,8 +5,9 @@
     "LlavaBaichuanForCausalLM"
   ],
   "auto_map": {
-    "AutoConfig": "baichuan-inc/Baichuan2-7B-Chat--configuration_baichuan.BaichuanConfig",
-    "AutoModelForCausalLM": "baichuan-inc/Baichuan2-7B-Chat--modeling_baichuan.BaichuanForCausalLM"
+    "AutoConfig": "llava_baichuan.LlavaBaichuanConfig",
+    "AutoModelForCausalLM": "baichuan-inc/Baichuan2-7B-Chat--modeling_baichuan.BaichuanForCausalLM",
+    "AutoModelForVisualQuestionAnswering": "llava_baichuan.LlavaBaichuanForCausalLM"
   },
   "bos_token_id": 1,
   "eos_token_id": 2,
llava_baichuan.py ADDED
@@ -0,0 +1,145 @@
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.generation.utils import GenerateOutput
+
+ from llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+ from configuration_baichuan import BaichuanConfig
+ from modeling_baichuan import BaichuanModel, BaichuanForCausalLM
+
+
+ class LlavaBaichuanConfig(BaichuanConfig):
+     model_type = "llava_baichuan"
+
+
+ class LlavaBaichuanModel(LlavaMetaModel, BaichuanModel):
+     config_class = LlavaBaichuanConfig
+
+     def __init__(self, config: BaichuanConfig):
+         super(LlavaBaichuanModel, self).__init__(config)
+
+
+ class LlavaBaichuanForCausalLM(BaichuanForCausalLM, LlavaMetaForCausalLM):
+     config_class = LlavaBaichuanConfig
+
+     def __init__(self, config):
+         # Skip BaichuanForCausalLM.__init__ so the backbone is built once, as LlavaBaichuanModel
+         super(BaichuanForCausalLM, self).__init__(config)
+         self.model = LlavaBaichuanModel(config)
+
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_model(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         image_sizes: Optional[List[List[int]]] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+         if inputs_embeds is None:
+             # Splice projected image features into the token embeddings
+             # (implemented in LlavaMetaForCausalLM)
+             (
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 past_key_values,
+                 inputs_embeds,
+                 labels
+             ) = self.prepare_inputs_labels_for_multimodal(
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 past_key_values,
+                 labels,
+                 images,
+                 image_sizes
+             )
+
+         return super().forward(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict
+         )
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+         position_ids = kwargs.pop("position_ids", None)
+         attention_mask = kwargs.pop("attention_mask", None)
+         if "inputs_embeds" in kwargs:
+             raise NotImplementedError("`inputs_embeds` is not supported")
+
+         if images is not None:
+             # Build multimodal inputs_embeds from the text ids and image features
+             (
+                 inputs,
+                 position_ids,
+                 attention_mask,
+                 _,
+                 inputs_embeds,
+                 _
+             ) = self.prepare_inputs_labels_for_multimodal(
+                 inputs,
+                 position_ids,
+                 attention_mask,
+                 None,
+                 None,
+                 images,
+                 image_sizes=image_sizes
+             )
+         else:
+             # Text-only path: embed the token ids directly
+             inputs_embeds = self.get_model().embed_tokens(inputs)
+
+         return super().generate(
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             **kwargs
+         )
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
+                                       inputs_embeds=None, **kwargs):
+         images = kwargs.pop("images", None)
+         image_sizes = kwargs.pop("image_sizes", None)
+         inputs = super().prepare_inputs_for_generation(
+             input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+         )
+         # Thread the image tensors through to forward() during generation
+         if images is not None:
+             inputs['images'] = images
+         if image_sizes is not None:
+             inputs['image_sizes'] = image_sizes
+         return inputs
+
+
+ # Expose the custom classes to transformers' Auto* machinery
+ AutoConfig.register("llava_baichuan", LlavaBaichuanConfig)
+ AutoModelForCausalLM.register(LlavaBaichuanConfig, LlavaBaichuanForCausalLM)
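A rough usage sketch for the class above (not part of the commit): token ids go in through `inputs` and preprocessed pixel values through `images`; the checkpoint path, token ids, and image resolution below are placeholders.

import torch

model = LlavaBaichuanForCausalLM.from_pretrained("path/to/checkpoint")  # placeholder path

input_ids = torch.tensor([[1, 2, 3]])   # placeholder ids from a Baichuan tokenizer
images = torch.randn(1, 3, 336, 336)    # placeholder pixel values from the vision preprocessor

output_ids = model.generate(
    inputs=input_ids,
    images=images,
    image_sizes=[[336, 336]],
    max_new_tokens=32,
)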