Vision-CAIR commited on
Commit
9763171
1 Parent(s): bf4036b

Upload 60 files

Browse files
eval_configs/minigpt4_eval.yaml CHANGED
@@ -1,14 +1,11 @@
1
  model:
2
- arch: mini_gpt4
3
- model_type: pretrain_vicuna
4
- freeze_vit: True
5
- freeze_qformer: True
6
  max_txt_len: 160
7
  end_sym: "###"
8
  low_resource: True
9
- prompt_path: "prompts/alignment.txt"
10
  prompt_template: '###Human: {} ###Assistant: '
11
- ckpt: '/path/to/pretrained/ckpt/'
12
 
13
 
14
  datasets:
 
1
  model:
2
+ arch: minigpt4
3
+ model_type: pretrain_vicuna0
 
 
4
  max_txt_len: 160
5
  end_sym: "###"
6
  low_resource: True
 
7
  prompt_template: '###Human: {} ###Assistant: '
8
+ ckpt: 'please set this value to the path of pretrained checkpoint'
9
 
10
 
11
  datasets:
eval_configs/minigptv2_eval.yaml CHANGED
@@ -5,7 +5,7 @@ model:
5
  end_sym: "</s>"
6
  low_resource: True
7
  prompt_template: '[INST] {} [/INST]'
8
- ckpt: 'minigptv2_checkpoint.pth'
9
  lora_r: 64
10
  lora_alpha: 16
11
 
 
5
  end_sym: "</s>"
6
  low_resource: True
7
  prompt_template: '[INST] {} [/INST]'
8
+ ckpt: 'please set this value to the path of pretrained checkpoint'
9
  lora_r: 64
10
  lora_alpha: 16
11
 
minigpt4/common/dist_utils.py CHANGED
@@ -55,7 +55,10 @@ def is_main_process():
55
 
56
 
57
  def init_distributed_mode(args):
58
- if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
 
 
 
59
  args.rank = int(os.environ["RANK"])
60
  args.world_size = int(os.environ["WORLD_SIZE"])
61
  args.gpu = int(os.environ["LOCAL_RANK"])
 
55
 
56
 
57
  def init_distributed_mode(args):
58
+ if args.distributed is False:
59
+ print("Not using distributed mode")
60
+ return
61
+ elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
62
  args.rank = int(os.environ["RANK"])
63
  args.world_size = int(os.environ["WORLD_SIZE"])
64
  args.gpu = int(os.environ["LOCAL_RANK"])
minigpt4/configs/models/minigpt_v2.yaml CHANGED
@@ -11,7 +11,7 @@ model:
11
  # generation configs
12
  prompt: ""
13
 
14
- llama_model: "meta-llama/Llama-2-7b-chat-hf"
15
  lora_r: 64
16
  lora_alpha: 16
17
 
 
11
  # generation configs
12
  prompt: ""
13
 
14
+ llama_model: "please set this value to the path of llama2-chat-7b"
15
  lora_r: 64
16
  lora_alpha: 16
17
 
minigpt4/conversation/conversation.py CHANGED
@@ -1,10 +1,11 @@
1
  import argparse
2
  import time
 
3
  from PIL import Image
4
 
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
7
- from transformers import StoppingCriteria, StoppingCriteriaList
8
 
9
  import dataclasses
10
  from enum import auto, Enum
@@ -39,18 +40,18 @@ class Conversation:
39
  ret = self.system + self.sep
40
  for role, message in self.messages:
41
  if message:
42
- ret += role + ": " + message + self.sep
43
  else:
44
- ret += role + ":"
45
  return ret
46
  elif self.sep_style == SeparatorStyle.TWO:
47
  seps = [self.sep, self.sep2]
48
  ret = self.system + seps[0]
49
  for i, (role, message) in enumerate(self.messages):
50
  if message:
51
- ret += role + ": " + message + seps[i % 2]
52
  else:
53
- ret += role + ":"
54
  return ret
55
  else:
56
  raise ValueError(f"Invalid style: {self.sep_style}")
@@ -106,26 +107,39 @@ class StoppingCriteriaSub(StoppingCriteria):
106
  return False
107
 
108
 
109
- CONV_VISION = Conversation(
110
  system="Give the following image: <Img>ImageContent</Img>. "
111
  "You will be able to see the image once I provide it to you. Please answer my questions.",
112
- roles=("Human", "Assistant"),
113
  messages=[],
114
  offset=2,
115
  sep_style=SeparatorStyle.SINGLE,
116
  sep="###",
117
  )
118
 
 
 
 
 
 
 
 
 
 
 
119
 
120
 
121
  class Chat:
122
- def __init__(self, model, vis_processor, device='cuda:0'):
123
  self.device = device
124
  self.model = model
125
  self.vis_processor = vis_processor
126
- stop_words_ids = [torch.tensor([835]).to(self.device),
127
- torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
128
- self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
 
 
 
129
 
130
  def ask(self, text, conv):
131
  if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
@@ -134,11 +148,19 @@ class Chat:
134
  else:
135
  conv.append_message(conv.roles[0], text)
136
 
137
- def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
138
- repetition_penalty=1.0, length_penalty=1, temperature=1.0):
139
  conv.append_message(conv.roles[1], None)
140
  embs = self.get_context_emb(conv, img_list)
141
- outputs = self.model.llama_model.generate(
 
 
 
 
 
 
 
 
142
  inputs_embeds=embs,
143
  max_new_tokens=max_new_tokens,
144
  stopping_criteria=self.stopping_criteria,
@@ -148,18 +170,33 @@ class Chat:
148
  top_p=top_p,
149
  repetition_penalty=repetition_penalty,
150
  length_penalty=length_penalty,
151
- temperature=temperature,
152
  )
153
- output_token = outputs[0]
154
- if output_token[0] == 0:
155
- output_token = output_token[1:]
156
- output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
 
 
 
 
157
  output_text = output_text.split('###')[0] # remove the stop sign '###'
158
  output_text = output_text.split('Assistant:')[-1].strip()
 
159
  conv.messages[-1][1] = output_text
160
  return output_text, output_token.cpu().numpy()
161
 
162
- def upload_img(self, image, conv, img_list):
 
 
 
 
 
 
 
 
 
 
163
  if isinstance(image, str): # is a image path
164
  raw_image = Image.open(image).convert('RGB')
165
  image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
@@ -173,9 +210,12 @@ class Chat:
173
 
174
  image_emb, _ = self.model.encode_img(image)
175
  img_list.append(image_emb)
 
 
176
  conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
 
177
  msg = "Received."
178
- # self.conv.append_message(self.conv.roles[1], msg)
179
  return msg
180
 
181
  def get_context_emb(self, conv, img_list):
@@ -188,7 +228,9 @@ class Chat:
188
  # only add bos to the first seg
189
  for i, seg in enumerate(prompt_segs)
190
  ]
191
- seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
 
 
192
  mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
193
  mixed_embs = torch.cat(mixed_embs, dim=1)
194
  return mixed_embs
 
1
  import argparse
2
  import time
3
+ from threading import Thread
4
  from PIL import Image
5
 
6
  import torch
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
8
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
9
 
10
  import dataclasses
11
  from enum import auto, Enum
 
40
  ret = self.system + self.sep
41
  for role, message in self.messages:
42
  if message:
43
+ ret += role + message + self.sep
44
  else:
45
+ ret += role
46
  return ret
47
  elif self.sep_style == SeparatorStyle.TWO:
48
  seps = [self.sep, self.sep2]
49
  ret = self.system + seps[0]
50
  for i, (role, message) in enumerate(self.messages):
51
  if message:
52
+ ret += role + message + seps[i % 2]
53
  else:
54
+ ret += role
55
  return ret
56
  else:
57
  raise ValueError(f"Invalid style: {self.sep_style}")
 
107
  return False
108
 
109
 
110
+ CONV_VISION_Vicuna0 = Conversation(
111
  system="Give the following image: <Img>ImageContent</Img>. "
112
  "You will be able to see the image once I provide it to you. Please answer my questions.",
113
+ roles=("Human: ", "Assistant: "),
114
  messages=[],
115
  offset=2,
116
  sep_style=SeparatorStyle.SINGLE,
117
  sep="###",
118
  )
119
 
120
+ CONV_VISION_LLama2 = Conversation(
121
+ system="Give the following image: <Img>ImageContent</Img>. "
122
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
123
+ roles=("<s>[INST] ", " [/INST] "),
124
+ messages=[],
125
+ offset=2,
126
+ sep_style=SeparatorStyle.SINGLE,
127
+ sep="",
128
+ )
129
+
130
 
131
 
132
  class Chat:
133
+ def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
134
  self.device = device
135
  self.model = model
136
  self.vis_processor = vis_processor
137
+
138
+ if stopping_criteria is not None:
139
+ self.stopping_criteria = stopping_criteria
140
+ else:
141
+ stop_words_ids = [torch.tensor([2]).to(self.device)]
142
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
143
 
144
  def ask(self, text, conv):
145
  if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
 
148
  else:
149
  conv.append_message(conv.roles[0], text)
150
 
151
+ def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
152
+ repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
153
  conv.append_message(conv.roles[1], None)
154
  embs = self.get_context_emb(conv, img_list)
155
+
156
+ current_max_len = embs.shape[1] + max_new_tokens
157
+ if current_max_len - max_length > 0:
158
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
159
+ 'The model will not see the contexts outside the range.')
160
+ begin_idx = max(0, current_max_len - max_length)
161
+ embs = embs[:, begin_idx:]
162
+
163
+ generation_kwargs = dict(
164
  inputs_embeds=embs,
165
  max_new_tokens=max_new_tokens,
166
  stopping_criteria=self.stopping_criteria,
 
170
  top_p=top_p,
171
  repetition_penalty=repetition_penalty,
172
  length_penalty=length_penalty,
173
+ temperature=float(temperature),
174
  )
175
+ return generation_kwargs
176
+
177
+ def answer(self, conv, img_list, **kargs):
178
+ generation_dict = self.answer_prepare(conv, img_list, **kargs)
179
+
180
+ output_token = self.model.llama_model.generate(**generation_dict)[0]
181
+ output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
182
+
183
  output_text = output_text.split('###')[0] # remove the stop sign '###'
184
  output_text = output_text.split('Assistant:')[-1].strip()
185
+
186
  conv.messages[-1][1] = output_text
187
  return output_text, output_token.cpu().numpy()
188
 
189
+ def stream_answer(self, conv, img_list, **kargs):
190
+ generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
191
+ streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
192
+ generation_kwargs['streamer'] = streamer
193
+ thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
194
+ thread.start()
195
+ return streamer
196
+
197
+ def encode_img(self, img_list):
198
+ image = img_list[0]
199
+ img_list.pop(0)
200
  if isinstance(image, str): # is a image path
201
  raw_image = Image.open(image).convert('RGB')
202
  image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
 
210
 
211
  image_emb, _ = self.model.encode_img(image)
212
  img_list.append(image_emb)
213
+
214
+ def upload_img(self, image, conv, img_list):
215
  conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
216
+ img_list.append(image)
217
  msg = "Received."
218
+
219
  return msg
220
 
221
  def get_context_emb(self, conv, img_list):
 
228
  # only add bos to the first seg
229
  for i, seg in enumerate(prompt_segs)
230
  ]
231
+ print('debug device: ', self.device)
232
+ print('debug model device: ', self.model.device)
233
+ seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
234
  mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
235
  mixed_embs = torch.cat(mixed_embs, dim=1)
236
  return mixed_embs
minigpt4/datasets/datasets/cc_sbu_dataset.py CHANGED
@@ -22,7 +22,7 @@ class CCSBUDataset(BaseDataset):
22
  def to_dict(self, sample):
23
  return {
24
  "image": sample[0],
25
- "text_input": self.text_processor(sample[1]["caption"]),
26
  }
27
 
28
 
@@ -42,6 +42,6 @@ class CCSBUAlignDataset(CaptionDataset):
42
 
43
  return {
44
  "image": image,
45
- "text_input": caption,
46
  "image_id": self.img_ids[ann["image_id"]],
47
  }
 
22
  def to_dict(self, sample):
23
  return {
24
  "image": sample[0],
25
+ "answer": self.text_processor(sample[1]["caption"]),
26
  }
27
 
28
 
 
42
 
43
  return {
44
  "image": image,
45
+ "answer": caption,
46
  "image_id": self.img_ids[ann["image_id"]],
47
  }
minigpt4/datasets/datasets/laion_dataset.py CHANGED
@@ -26,6 +26,6 @@ class LaionDataset(BaseDataset):
26
  def to_dict(self, sample):
27
  return {
28
  "image": sample[0],
29
- "text_input": self.text_processor(sample[1]["caption"]),
30
  }
31
 
 
26
  def to_dict(self, sample):
27
  return {
28
  "image": sample[0],
29
+ "answer": self.text_processor(sample[1]["caption"]),
30
  }
31
 
minigpt4/models/__init__.py CHANGED
@@ -11,16 +11,18 @@ from omegaconf import OmegaConf
11
 
12
  from minigpt4.common.registry import registry
13
  from minigpt4.models.base_model import BaseModel
14
- from minigpt4.models.blip2 import Blip2Base
15
- from minigpt4.models.mini_gpt4 import MiniGPT4
 
16
  from minigpt4.processors.base_processor import BaseProcessor
17
 
18
 
19
  __all__ = [
20
  "load_model",
21
  "BaseModel",
22
- "Blip2Base",
23
  "MiniGPT4",
 
24
  ]
25
 
26
 
 
11
 
12
  from minigpt4.common.registry import registry
13
  from minigpt4.models.base_model import BaseModel
14
+ from minigpt4.models.minigpt_base import MiniGPTBase
15
+ from minigpt4.models.minigpt4 import MiniGPT4
16
+ from minigpt4.models.minigpt_v2 import MiniGPTv2
17
  from minigpt4.processors.base_processor import BaseProcessor
18
 
19
 
20
  __all__ = [
21
  "load_model",
22
  "BaseModel",
23
+ "MiniGPTBase",
24
  "MiniGPT4",
25
+ "MiniGPTv2"
26
  ]
27
 
28
 
minigpt4/models/base_model.py CHANGED
@@ -5,15 +5,26 @@
5
  For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
  """
7
 
8
- import logging
9
  import os
 
 
10
 
 
11
  import numpy as np
12
  import torch
13
  import torch.nn as nn
 
 
 
 
 
 
 
 
14
  from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
15
  from minigpt4.common.utils import get_abs_path, is_url
16
- from omegaconf import OmegaConf
 
17
 
18
 
19
  class BaseModel(nn.Module):
@@ -24,7 +35,7 @@ class BaseModel(nn.Module):
24
 
25
  @property
26
  def device(self):
27
- return list(self.parameters())[0].device
28
 
29
  def load_checkpoint(self, url_or_filename):
30
  """
@@ -117,131 +128,121 @@ class BaseModel(nn.Module):
117
  else:
118
  return tot
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- class BaseEncoder(nn.Module):
122
- """
123
- Base class for primitive encoders, such as ViT, TimeSformer, etc.
124
- """
125
 
126
- def __init__(self):
127
- super().__init__()
128
 
129
- def forward_features(self, samples, **kwargs):
130
- raise NotImplementedError
131
 
132
- @property
133
- def device(self):
134
- return list(self.parameters())[0].device
135
-
136
-
137
- class SharedQueueMixin:
138
- @torch.no_grad()
139
- def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
140
- # gather keys before updating queue
141
- image_feats = concat_all_gather(image_feat)
142
- text_feats = concat_all_gather(text_feat)
143
-
144
- batch_size = image_feats.shape[0]
145
-
146
- ptr = int(self.queue_ptr)
147
- assert self.queue_size % batch_size == 0 # for simplicity
148
-
149
- # replace the keys at ptr (dequeue and enqueue)
150
- self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
151
- self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
152
-
153
- if idxs is not None:
154
- idxs = concat_all_gather(idxs)
155
- self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
156
-
157
- ptr = (ptr + batch_size) % self.queue_size # move pointer
158
- self.queue_ptr[0] = ptr
159
-
160
-
161
- class MomentumDistilationMixin:
162
- @torch.no_grad()
163
- def copy_params(self):
164
- for model_pair in self.model_pairs:
165
- for param, param_m in zip(
166
- model_pair[0].parameters(), model_pair[1].parameters()
167
- ):
168
- param_m.data.copy_(param.data) # initialize
169
- param_m.requires_grad = False # not update by gradient
170
-
171
- @torch.no_grad()
172
- def _momentum_update(self):
173
- for model_pair in self.model_pairs:
174
- for param, param_m in zip(
175
- model_pair[0].parameters(), model_pair[1].parameters()
176
- ):
177
- param_m.data = param_m.data * self.momentum + param.data * (
178
- 1.0 - self.momentum
179
- )
180
-
181
-
182
- class GatherLayer(torch.autograd.Function):
183
- """
184
- Gather tensors from all workers with support for backward propagation:
185
- This implementation does not cut the gradients as torch.distributed.all_gather does.
186
- """
187
-
188
- @staticmethod
189
- def forward(ctx, x):
190
- output = [
191
- torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
192
- ]
193
- torch.distributed.all_gather(output, x)
194
- return tuple(output)
195
-
196
- @staticmethod
197
- def backward(ctx, *grads):
198
- all_gradients = torch.stack(grads)
199
- torch.distributed.all_reduce(all_gradients)
200
- return all_gradients[torch.distributed.get_rank()]
201
-
202
-
203
- def all_gather_with_grad(tensors):
204
- """
205
- Performs all_gather operation on the provided tensors.
206
- Graph remains connected for backward grad computation.
207
- """
208
- # Queue the gathered tensors
209
- world_size = torch.distributed.get_world_size()
210
- # There is no need for reduction in the single-proc case
211
- if world_size == 1:
212
- return tensors
213
-
214
- # tensor_all = GatherLayer.apply(tensors)
215
- tensor_all = GatherLayer.apply(tensors)
216
-
217
- return torch.cat(tensor_all, dim=0)
218
-
219
-
220
- @torch.no_grad()
221
- def concat_all_gather(tensor):
222
- """
223
- Performs all_gather operation on the provided tensors.
224
- *** Warning ***: torch.distributed.all_gather has no gradient.
225
- """
226
- # if use distributed training
227
- if not is_dist_avail_and_initialized():
228
- return tensor
229
-
230
- tensors_gather = [
231
- torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
232
- ]
233
- torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
234
-
235
- output = torch.cat(tensors_gather, dim=0)
236
- return output
237
-
238
-
239
- def tile(x, dim, n_tile):
240
- init_dim = x.size(dim)
241
- repeat_idx = [1] * x.dim()
242
- repeat_idx[dim] = n_tile
243
- x = x.repeat(*(repeat_idx))
244
- order_index = torch.LongTensor(
245
- np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
246
- )
247
- return torch.index_select(x, dim, order_index.to(x.device))
 
5
  For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
  """
7
 
 
8
  import os
9
+ import logging
10
+ import contextlib
11
 
12
+ from omegaconf import OmegaConf
13
  import numpy as np
14
  import torch
15
  import torch.nn as nn
16
+ from transformers import BertTokenizer, LlamaTokenizer
17
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
18
+ from peft import (
19
+ LoraConfig,
20
+ get_peft_model,
21
+ prepare_model_for_int8_training,
22
+ )
23
+
24
  from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
25
  from minigpt4.common.utils import get_abs_path, is_url
26
+ from minigpt4.models.eva_vit import create_eva_vit_g
27
+
28
 
29
 
30
  class BaseModel(nn.Module):
 
35
 
36
  @property
37
  def device(self):
38
+ return list(self.parameters())[-1].device
39
 
40
  def load_checkpoint(self, url_or_filename):
41
  """
 
128
  else:
129
  return tot
130
 
131
+ def maybe_autocast(self, dtype=torch.float16):
132
+ # if on cpu, don't use autocast
133
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
134
+ enable_autocast = self.device != torch.device("cpu")
135
+
136
+ if enable_autocast:
137
+ return torch.cuda.amp.autocast(dtype=dtype)
138
+ else:
139
+ return contextlib.nullcontext()
140
+
141
+ @classmethod
142
+ def init_vision_encoder(
143
+ cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze
144
+ ):
145
+ logging.info('Loading VIT')
146
+
147
+ assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
148
+ if not freeze:
149
+ precision = "fp32" # fp16 is not for training
150
+
151
+ visual_encoder = create_eva_vit_g(
152
+ img_size, drop_path_rate, use_grad_checkpoint, precision
153
+ )
154
+
155
+ ln_vision = LayerNorm(visual_encoder.num_features)
156
+
157
+ if freeze:
158
+ for name, param in visual_encoder.named_parameters():
159
+ param.requires_grad = False
160
+ visual_encoder = visual_encoder.eval()
161
+ visual_encoder.train = disabled_train
162
+ for name, param in ln_vision.named_parameters():
163
+ param.requires_grad = False
164
+ ln_vision = ln_vision.eval()
165
+ ln_vision.train = disabled_train
166
+ logging.info("freeze vision encoder")
167
+
168
+ logging.info('Loading VIT Done')
169
+ return visual_encoder, ln_vision
170
+
171
+ def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
172
+ lora_target_modules=["q_proj","v_proj"], **lora_kargs):
173
+ logging.info('Loading LLAMA')
174
+ llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
175
+ llama_tokenizer.pad_token = "$$"
176
+
177
+ if low_resource:
178
+ llama_model = LlamaForCausalLM.from_pretrained(
179
+ llama_model_path,
180
+ torch_dtype=torch.float16,
181
+ load_in_8bit=True,
182
+ device_map={'': low_res_device}
183
+ )
184
+ else:
185
+ llama_model = LlamaForCausalLM.from_pretrained(
186
+ llama_model_path,
187
+ torch_dtype=torch.float16,
188
+ )
189
+
190
+ if lora_r > 0:
191
+ llama_model = prepare_model_for_int8_training(llama_model)
192
+ loraconfig = LoraConfig(
193
+ r=lora_r,
194
+ bias="none",
195
+ task_type="CAUSAL_LM",
196
+ target_modules=lora_target_modules,
197
+ **lora_kargs
198
+ )
199
+ llama_model = get_peft_model(llama_model, loraconfig)
200
+
201
+ llama_model.print_trainable_parameters()
202
+
203
+ else:
204
+ for name, param in llama_model.named_parameters():
205
+ param.requires_grad = False
206
+ logging.info('Loading LLAMA Done')
207
+ return llama_model, llama_tokenizer
208
+
209
+
210
+ def load_from_pretrained(self, url_or_filename):
211
+ if is_url(url_or_filename):
212
+ cached_file = download_cached_file(
213
+ url_or_filename, check_hash=False, progress=True
214
+ )
215
+ checkpoint = torch.load(cached_file, map_location="cpu")
216
+ elif os.path.isfile(url_or_filename):
217
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
218
+ else:
219
+ raise RuntimeError("checkpoint url or path is invalid")
220
+
221
+ state_dict = checkpoint["model"]
222
+
223
+ msg = self.load_state_dict(state_dict, strict=False)
224
+
225
+ # logging.info("Missing keys {}".format(msg.missing_keys))
226
+ logging.info("load checkpoint from %s" % url_or_filename)
227
+
228
+ return msg
229
+
230
+
231
+ def disabled_train(self, mode=True):
232
+ """Overwrite model.train with this function to make sure train/eval mode
233
+ does not change anymore."""
234
+ return self
235
+
236
+
237
+ class LayerNorm(nn.LayerNorm):
238
+ """Subclass torch's LayerNorm to handle fp16."""
239
+
240
+ def forward(self, x: torch.Tensor):
241
+ orig_type = x.dtype
242
+ ret = super().forward(x.type(torch.float32))
243
+ return ret.type(orig_type)
244
+
245
 
 
 
 
 
246
 
 
 
247
 
 
 
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/models/minigpt_base.py CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
7
 
8
  from minigpt4.common.registry import registry
9
  from minigpt4.models.base_model import BaseModel
 
10
 
11
 
12
 
@@ -365,8 +366,8 @@ class MiniGPTBase(BaseModel):
365
  do_sample=do_sample,
366
  min_length=min_length,
367
  top_p=top_p,
368
- repetition_penalty=repetition_penalty
369
- # stopping_criteria=stopping_criteria,
370
  )
371
 
372
  answers = []
 
7
 
8
  from minigpt4.common.registry import registry
9
  from minigpt4.models.base_model import BaseModel
10
+ from transformers import StoppingCriteria, StoppingCriteriaList
11
 
12
 
13
 
 
366
  do_sample=do_sample,
367
  min_length=min_length,
368
  top_p=top_p,
369
+ repetition_penalty=repetition_penalty,
370
+ stopping_criteria=stopping_criteria,
371
  )
372
 
373
  answers = []
minigpt4/models/modeling_llama.py CHANGED
@@ -1,628 +1,17 @@
1
- # This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
2
-
3
- """ PyTorch LLaMA model."""
4
  import math
5
  from typing import List, Optional, Tuple, Union
6
 
7
  import torch
8
- import torch.utils.checkpoint
9
- from torch import nn
10
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
11
-
12
- from transformers.activations import ACT2FN
13
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
14
- from transformers.modeling_utils import PreTrainedModel
15
- from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
16
- from transformers.models.llama.configuration_llama import LlamaConfig
17
-
18
-
19
- logger = logging.get_logger(__name__)
20
-
21
- _CONFIG_FOR_DOC = "LlamaConfig"
22
-
23
-
24
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
25
- def _make_causal_mask(
26
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
27
- ):
28
- """
29
- Make causal mask used for bi-directional self-attention.
30
- """
31
- bsz, tgt_len = input_ids_shape
32
- mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
33
- mask_cond = torch.arange(mask.size(-1), device=device)
34
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
35
- mask = mask.to(dtype)
36
-
37
- if past_key_values_length > 0:
38
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
39
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
40
-
41
-
42
- # Copied from transformers.models.bart.modeling_bart._expand_mask
43
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
44
- """
45
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
46
- """
47
- bsz, src_len = mask.size()
48
- tgt_len = tgt_len if tgt_len is not None else src_len
49
-
50
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
51
-
52
- inverted_mask = 1.0 - expanded_mask
53
-
54
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
55
-
56
-
57
- class LlamaRMSNorm(nn.Module):
58
- def __init__(self, hidden_size, eps=1e-6):
59
- """
60
- LlamaRMSNorm is equivalent to T5LayerNorm
61
- """
62
- super().__init__()
63
- self.weight = nn.Parameter(torch.ones(hidden_size))
64
- self.variance_epsilon = eps
65
-
66
- def forward(self, hidden_states):
67
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
68
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
69
-
70
- # convert into half-precision if necessary
71
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
72
- hidden_states = hidden_states.to(self.weight.dtype)
73
-
74
- return self.weight * hidden_states
75
-
76
-
77
- class LlamaRotaryEmbedding(torch.nn.Module):
78
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
79
- super().__init__()
80
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
81
- self.register_buffer("inv_freq", inv_freq)
82
-
83
- # Build here to make `torch.jit.trace` work.
84
- self.max_seq_len_cached = max_position_embeddings
85
- t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
86
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
87
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
88
- emb = torch.cat((freqs, freqs), dim=-1)
89
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
90
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
91
-
92
- def forward(self, x, seq_len=None):
93
- # x: [bs, num_attention_heads, seq_len, head_size]
94
- # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
95
- if seq_len > self.max_seq_len_cached:
96
- self.max_seq_len_cached = seq_len
97
- t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
98
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
99
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
100
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
101
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
102
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
103
- return (
104
- self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
105
- self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
106
- )
107
-
108
-
109
- def rotate_half(x):
110
- """Rotates half the hidden dims of the input."""
111
- x1 = x[..., : x.shape[-1] // 2]
112
- x2 = x[..., x.shape[-1] // 2 :]
113
- return torch.cat((-x2, x1), dim=-1)
114
-
115
-
116
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
117
- gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
118
- gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
119
- cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
120
- sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
121
- q_embed = (q * cos) + (rotate_half(q) * sin)
122
- k_embed = (k * cos) + (rotate_half(k) * sin)
123
- return q_embed, k_embed
124
-
125
-
126
- class LlamaMLP(nn.Module):
127
- def __init__(
128
- self,
129
- hidden_size: int,
130
- intermediate_size: int,
131
- hidden_act: str,
132
- ):
133
- super().__init__()
134
- self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
135
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
136
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
137
- self.act_fn = ACT2FN[hidden_act]
138
-
139
- def forward(self, x):
140
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
141
-
142
-
143
- class LlamaAttention(nn.Module):
144
- """Multi-headed attention from 'Attention Is All You Need' paper"""
145
-
146
- def __init__(self, config: LlamaConfig):
147
- super().__init__()
148
- self.config = config
149
- self.hidden_size = config.hidden_size
150
- self.num_heads = config.num_attention_heads
151
- self.head_dim = self.hidden_size // self.num_heads
152
- self.max_position_embeddings = config.max_position_embeddings
153
-
154
- if (self.head_dim * self.num_heads) != self.hidden_size:
155
- raise ValueError(
156
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
157
- f" and `num_heads`: {self.num_heads})."
158
- )
159
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
160
- self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
161
- self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
162
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
163
- self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
164
-
165
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
166
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
167
-
168
- def forward(
169
- self,
170
- hidden_states: torch.Tensor,
171
- attention_mask: Optional[torch.Tensor] = None,
172
- position_ids: Optional[torch.LongTensor] = None,
173
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
174
- output_attentions: bool = False,
175
- use_cache: bool = False,
176
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
177
- bsz, q_len, _ = hidden_states.size()
178
-
179
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
180
- key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
181
- value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
182
-
183
- kv_seq_len = key_states.shape[-2]
184
- if past_key_value is not None:
185
- kv_seq_len += past_key_value[0].shape[-2]
186
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
187
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
188
- # [bsz, nh, t, hd]
189
-
190
- if past_key_value is not None:
191
- # reuse k, v, self_attention
192
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
193
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
194
-
195
- past_key_value = (key_states, value_states) if use_cache else None
196
-
197
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
198
-
199
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
200
- raise ValueError(
201
- f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
202
- f" {attn_weights.size()}"
203
- )
204
-
205
- if attention_mask is not None:
206
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
207
- raise ValueError(
208
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
209
- )
210
- attn_weights = attn_weights + attention_mask
211
- attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
212
-
213
- # upcast attention to fp32
214
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
215
- attn_output = torch.matmul(attn_weights, value_states)
216
-
217
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
218
- raise ValueError(
219
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
220
- f" {attn_output.size()}"
221
- )
222
-
223
- attn_output = attn_output.transpose(1, 2)
224
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
225
-
226
- attn_output = self.o_proj(attn_output)
227
-
228
- if not output_attentions:
229
- attn_weights = None
230
-
231
- return attn_output, attn_weights, past_key_value
232
-
233
-
234
- class LlamaDecoderLayer(nn.Module):
235
- def __init__(self, config: LlamaConfig):
236
- super().__init__()
237
- self.hidden_size = config.hidden_size
238
- self.self_attn = LlamaAttention(config=config)
239
- self.mlp = LlamaMLP(
240
- hidden_size=self.hidden_size,
241
- intermediate_size=config.intermediate_size,
242
- hidden_act=config.hidden_act,
243
- )
244
- self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
245
- self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
246
-
247
- def forward(
248
- self,
249
- hidden_states: torch.Tensor,
250
- attention_mask: Optional[torch.Tensor] = None,
251
- position_ids: Optional[torch.LongTensor] = None,
252
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
253
- output_attentions: Optional[bool] = False,
254
- use_cache: Optional[bool] = False,
255
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
256
- """
257
- Args:
258
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
259
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
260
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
261
- output_attentions (`bool`, *optional*):
262
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
263
- returned tensors for more detail.
264
- use_cache (`bool`, *optional*):
265
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
266
- (see `past_key_values`).
267
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
268
- """
269
-
270
- residual = hidden_states
271
-
272
- hidden_states = self.input_layernorm(hidden_states)
273
-
274
- # Self Attention
275
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
276
- hidden_states=hidden_states,
277
- attention_mask=attention_mask,
278
- position_ids=position_ids,
279
- past_key_value=past_key_value,
280
- output_attentions=output_attentions,
281
- use_cache=use_cache,
282
- )
283
- hidden_states = residual + hidden_states
284
-
285
- # Fully Connected
286
- residual = hidden_states
287
- hidden_states = self.post_attention_layernorm(hidden_states)
288
- hidden_states = self.mlp(hidden_states)
289
- hidden_states = residual + hidden_states
290
-
291
- outputs = (hidden_states,)
292
-
293
- if output_attentions:
294
- outputs += (self_attn_weights,)
295
-
296
- if use_cache:
297
- outputs += (present_key_value,)
298
-
299
- return outputs
300
-
301
-
302
- LLAMA_START_DOCSTRING = r"""
303
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
304
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
305
- etc.)
306
-
307
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
308
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
309
- and behavior.
310
-
311
- Parameters:
312
- config ([`LlamaConfig`]):
313
- Model configuration class with all the parameters of the model. Initializing with a config file does not
314
- load the weights associated with the model, only the configuration. Check out the
315
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
316
- """
317
-
318
-
319
- @add_start_docstrings(
320
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
321
- LLAMA_START_DOCSTRING,
322
- )
323
- class LlamaPreTrainedModel(PreTrainedModel):
324
- config_class = LlamaConfig
325
- base_model_prefix = "model"
326
- supports_gradient_checkpointing = True
327
- _no_split_modules = ["LlamaDecoderLayer"]
328
- _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
329
-
330
- def _init_weights(self, module):
331
- std = self.config.initializer_range
332
- if isinstance(module, nn.Linear):
333
- module.weight.data.normal_(mean=0.0, std=std)
334
- if module.bias is not None:
335
- module.bias.data.zero_()
336
- elif isinstance(module, nn.Embedding):
337
- module.weight.data.normal_(mean=0.0, std=std)
338
- if module.padding_idx is not None:
339
- module.weight.data[module.padding_idx].zero_()
340
-
341
- def _set_gradient_checkpointing(self, module, value=False):
342
- if isinstance(module, LlamaModel):
343
- module.gradient_checkpointing = value
344
-
345
 
346
- LLAMA_INPUTS_DOCSTRING = r"""
347
- Args:
348
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
349
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
350
- it.
351
 
352
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
353
- [`PreTrainedTokenizer.__call__`] for details.
354
 
355
- [What are input IDs?](../glossary#input-ids)
356
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
357
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
358
-
359
- - 1 for tokens that are **not masked**,
360
- - 0 for tokens that are **masked**.
361
-
362
- [What are attention masks?](../glossary#attention-mask)
363
-
364
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
365
- [`PreTrainedTokenizer.__call__`] for details.
366
-
367
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
368
- `past_key_values`).
369
-
370
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
371
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
372
- information on the default strategy.
373
-
374
- - 1 indicates the head is **not masked**,
375
- - 0 indicates the head is **masked**.
376
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
377
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
378
- config.n_positions - 1]`.
379
-
380
- [What are position IDs?](../glossary#position-ids)
381
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
382
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
383
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
384
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
385
-
386
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
387
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
388
-
389
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
390
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
391
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
392
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
393
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
394
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
395
- model's internal embedding lookup matrix.
396
- use_cache (`bool`, *optional*):
397
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
398
- `past_key_values`).
399
- output_attentions (`bool`, *optional*):
400
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
401
- tensors for more detail.
402
- output_hidden_states (`bool`, *optional*):
403
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
404
- more detail.
405
- return_dict (`bool`, *optional*):
406
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
407
- """
408
-
409
-
410
- @add_start_docstrings(
411
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
412
- LLAMA_START_DOCSTRING,
413
- )
414
- class LlamaModel(LlamaPreTrainedModel):
415
- """
416
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
417
-
418
- Args:
419
- config: LlamaConfig
420
- """
421
-
422
- def __init__(self, config: LlamaConfig):
423
- super().__init__(config)
424
- self.padding_idx = config.pad_token_id
425
- self.vocab_size = config.vocab_size
426
-
427
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
428
- self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
429
- self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
430
-
431
- self.gradient_checkpointing = False
432
- # Initialize weights and apply final processing
433
- self.post_init()
434
-
435
- def get_input_embeddings(self):
436
- return self.embed_tokens
437
-
438
- def set_input_embeddings(self, value):
439
- self.embed_tokens = value
440
-
441
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
442
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
443
- # create causal mask
444
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
445
- combined_attention_mask = None
446
- if input_shape[-1] > 1:
447
- combined_attention_mask = _make_causal_mask(
448
- input_shape,
449
- inputs_embeds.dtype,
450
- device=inputs_embeds.device,
451
- past_key_values_length=past_key_values_length,
452
- )
453
-
454
- if attention_mask is not None:
455
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
456
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
457
- inputs_embeds.device
458
- )
459
- combined_attention_mask = (
460
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
461
- )
462
-
463
- return combined_attention_mask
464
-
465
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
466
- def forward(
467
- self,
468
- input_ids: torch.LongTensor = None,
469
- attention_mask: Optional[torch.Tensor] = None,
470
- position_ids: Optional[torch.LongTensor] = None,
471
- past_key_values: Optional[List[torch.FloatTensor]] = None,
472
- inputs_embeds: Optional[torch.FloatTensor] = None,
473
- query_embeds: Optional[torch.FloatTensor] = None,
474
- use_cache: Optional[bool] = None,
475
- output_attentions: Optional[bool] = None,
476
- output_hidden_states: Optional[bool] = None,
477
- return_dict: Optional[bool] = None,
478
- ) -> Union[Tuple, BaseModelOutputWithPast]:
479
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
480
- output_hidden_states = (
481
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
482
- )
483
- use_cache = use_cache if use_cache is not None else self.config.use_cache
484
-
485
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
486
-
487
- # retrieve input_ids and inputs_embeds
488
- if input_ids is not None and inputs_embeds is not None:
489
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
490
- elif input_ids is not None:
491
- batch_size, seq_length = input_ids.shape
492
- elif inputs_embeds is not None:
493
- batch_size, seq_length, _ = inputs_embeds.shape
494
- else:
495
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
496
-
497
- if inputs_embeds is None:
498
- inputs_embeds = self.embed_tokens(input_ids)
499
- if query_embeds is not None:
500
- inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
501
- batch_size, seq_length, _ = inputs_embeds.shape
502
-
503
- seq_length_with_past = seq_length
504
- past_key_values_length = 0
505
-
506
- if past_key_values is not None:
507
- past_key_values_length = past_key_values[0][0].shape[2]
508
- seq_length_with_past = seq_length_with_past + past_key_values_length
509
-
510
- if position_ids is None:
511
- device = input_ids.device if input_ids is not None else inputs_embeds.device
512
- position_ids = torch.arange(
513
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
514
- )
515
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
516
- else:
517
- position_ids = position_ids.view(-1, seq_length).long()
518
-
519
- # embed positions
520
- if attention_mask is None:
521
- attention_mask = torch.ones(
522
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
523
- )
524
- attention_mask = self._prepare_decoder_attention_mask(
525
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
526
- )
527
-
528
- hidden_states = inputs_embeds
529
-
530
- if self.gradient_checkpointing and self.training:
531
- if use_cache:
532
- logger.warning_once(
533
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
534
- )
535
- use_cache = False
536
-
537
- # decoder layers
538
- all_hidden_states = () if output_hidden_states else None
539
- all_self_attns = () if output_attentions else None
540
- next_decoder_cache = () if use_cache else None
541
-
542
- for idx, decoder_layer in enumerate(self.layers):
543
- if output_hidden_states:
544
- all_hidden_states += (hidden_states,)
545
-
546
- past_key_value = past_key_values[idx] if past_key_values is not None else None
547
-
548
- if self.gradient_checkpointing and self.training:
549
-
550
- def create_custom_forward(module):
551
- def custom_forward(*inputs):
552
- # None for past_key_value
553
- return module(*inputs, output_attentions, None)
554
-
555
- return custom_forward
556
-
557
- layer_outputs = torch.utils.checkpoint.checkpoint(
558
- create_custom_forward(decoder_layer),
559
- hidden_states,
560
- attention_mask,
561
- position_ids,
562
- None,
563
- )
564
- else:
565
- layer_outputs = decoder_layer(
566
- hidden_states,
567
- attention_mask=attention_mask,
568
- position_ids=position_ids,
569
- past_key_value=past_key_value,
570
- output_attentions=output_attentions,
571
- use_cache=use_cache,
572
- )
573
-
574
- hidden_states = layer_outputs[0]
575
-
576
- if use_cache:
577
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
578
-
579
- if output_attentions:
580
- all_self_attns += (layer_outputs[1],)
581
-
582
- hidden_states = self.norm(hidden_states)
583
-
584
- # add hidden states from the last decoder layer
585
- if output_hidden_states:
586
- all_hidden_states += (hidden_states,)
587
-
588
- next_cache = next_decoder_cache if use_cache else None
589
- if not return_dict:
590
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
591
- return BaseModelOutputWithPast(
592
- last_hidden_state=hidden_states,
593
- past_key_values=next_cache,
594
- hidden_states=all_hidden_states,
595
- attentions=all_self_attns,
596
- )
597
-
598
-
599
- class LlamaForCausalLM(LlamaPreTrainedModel):
600
- def __init__(self, config):
601
- super().__init__(config)
602
- self.model = LlamaModel(config)
603
-
604
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
605
-
606
- # Initialize weights and apply final processing
607
- self.post_init()
608
-
609
- def get_input_embeddings(self):
610
- return self.model.embed_tokens
611
-
612
- def set_input_embeddings(self, value):
613
- self.model.embed_tokens = value
614
-
615
- def get_output_embeddings(self):
616
- return self.lm_head
617
-
618
- def set_output_embeddings(self, new_embeddings):
619
- self.lm_head = new_embeddings
620
-
621
- def set_decoder(self, decoder):
622
- self.model = decoder
623
-
624
- def get_decoder(self):
625
- return self.model
626
 
627
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
628
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -633,12 +22,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
633
  position_ids: Optional[torch.LongTensor] = None,
634
  past_key_values: Optional[List[torch.FloatTensor]] = None,
635
  inputs_embeds: Optional[torch.FloatTensor] = None,
636
- query_embeds: Optional[torch.FloatTensor] = None,
637
  labels: Optional[torch.LongTensor] = None,
638
  use_cache: Optional[bool] = None,
639
  output_attentions: Optional[bool] = None,
640
  output_hidden_states: Optional[bool] = None,
641
  return_dict: Optional[bool] = None,
 
642
  ) -> Union[Tuple, CausalLMOutputWithPast]:
643
  r"""
644
  Args:
@@ -657,13 +46,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
657
  >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
658
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
659
 
660
- >>> prompt = "Hey, are you consciours? Can you talk to me?"
661
  >>> inputs = tokenizer(prompt, return_tensors="pt")
662
 
663
  >>> # Generate
664
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
665
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
666
- "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
667
  ```"""
668
 
669
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -679,7 +68,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
679
  position_ids=position_ids,
680
  past_key_values=past_key_values,
681
  inputs_embeds=inputs_embeds,
682
- query_embeds=query_embeds,
683
  use_cache=use_cache,
684
  output_attentions=output_attentions,
685
  output_hidden_states=output_hidden_states,
@@ -687,7 +75,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
687
  )
688
 
689
  hidden_states = outputs[0]
690
- logits = self.lm_head(hidden_states)
 
 
 
 
 
 
691
 
692
  loss = None
693
  if labels is not None:
@@ -695,12 +89,14 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
695
  shift_logits = logits[..., :-1, :].contiguous()
696
  shift_labels = labels[..., 1:].contiguous()
697
  # Flatten the tokens
698
- loss_fct = CrossEntropyLoss()
699
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
700
  shift_labels = shift_labels.view(-1)
701
  # Enable model parallelism
702
  shift_labels = shift_labels.to(shift_logits.device)
703
  loss = loss_fct(shift_logits, shift_labels)
 
 
704
 
705
  if not return_dict:
706
  output = (logits,) + outputs[1:]
@@ -713,43 +109,3 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
713
  hidden_states=outputs.hidden_states,
714
  attentions=outputs.attentions,
715
  )
716
-
717
- def prepare_inputs_for_generation(
718
- self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
719
- ):
720
- if past_key_values:
721
- input_ids = input_ids[:, -1:]
722
-
723
- position_ids = kwargs.get("position_ids", None)
724
- if attention_mask is not None and position_ids is None:
725
- # create position_ids on the fly for batch generation
726
- position_ids = attention_mask.long().cumsum(-1) - 1
727
- position_ids.masked_fill_(attention_mask == 0, 1)
728
- if past_key_values:
729
- position_ids = position_ids[:, -1].unsqueeze(-1)
730
- query_embeds = None
731
-
732
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
733
- if inputs_embeds is not None and past_key_values is None:
734
- model_inputs = {"inputs_embeds": inputs_embeds}
735
- else:
736
- model_inputs = {"input_ids": input_ids}
737
-
738
- model_inputs.update(
739
- {
740
- "position_ids": position_ids,
741
- "query_embeds": query_embeds,
742
- "past_key_values": past_key_values,
743
- "use_cache": kwargs.get("use_cache"),
744
- "attention_mask": attention_mask,
745
- }
746
- )
747
- return model_inputs
748
-
749
- @staticmethod
750
- def _reorder_cache(past_key_values, beam_idx):
751
- reordered_past = ()
752
- for layer_past in past_key_values:
753
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
754
- return reordered_past
755
-
 
 
 
 
1
  import math
2
  from typing import List, Optional, Tuple, Union
3
 
4
  import torch
5
+ import torch.nn.functional as F
6
+ from torch.nn import CrossEntropyLoss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+ from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
11
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
 
12
 
 
 
13
 
14
+ class LlamaForCausalLM(LlamaForCausalLMOrig):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
17
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
22
  position_ids: Optional[torch.LongTensor] = None,
23
  past_key_values: Optional[List[torch.FloatTensor]] = None,
24
  inputs_embeds: Optional[torch.FloatTensor] = None,
 
25
  labels: Optional[torch.LongTensor] = None,
26
  use_cache: Optional[bool] = None,
27
  output_attentions: Optional[bool] = None,
28
  output_hidden_states: Optional[bool] = None,
29
  return_dict: Optional[bool] = None,
30
+ reduction: Optional[str] = "mean",
31
  ) -> Union[Tuple, CausalLMOutputWithPast]:
32
  r"""
33
  Args:
 
46
  >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
47
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
48
 
49
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
50
  >>> inputs = tokenizer(prompt, return_tensors="pt")
51
 
52
  >>> # Generate
53
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
54
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
55
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
56
  ```"""
57
 
58
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
68
  position_ids=position_ids,
69
  past_key_values=past_key_values,
70
  inputs_embeds=inputs_embeds,
 
71
  use_cache=use_cache,
72
  output_attentions=output_attentions,
73
  output_hidden_states=output_hidden_states,
 
75
  )
76
 
77
  hidden_states = outputs[0]
78
+ if self.config.pretraining_tp > 1:
79
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
80
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
81
+ logits = torch.cat(logits, dim=-1)
82
+ else:
83
+ logits = self.lm_head(hidden_states)
84
+ logits = logits.float()
85
 
86
  loss = None
87
  if labels is not None:
 
89
  shift_logits = logits[..., :-1, :].contiguous()
90
  shift_labels = labels[..., 1:].contiguous()
91
  # Flatten the tokens
92
+ loss_fct = CrossEntropyLoss(reduction=reduction)
93
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
94
  shift_labels = shift_labels.view(-1)
95
  # Enable model parallelism
96
  shift_labels = shift_labels.to(shift_logits.device)
97
  loss = loss_fct(shift_logits, shift_labels)
98
+ if reduction == "none":
99
+ loss = loss.view(logits.size(0), -1).mean(1)
100
 
101
  if not return_dict:
102
  output = (logits,) + outputs[1:]
 
109
  hidden_states=outputs.hidden_states,
110
  attentions=outputs.attentions,
111
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/runners/runner_base.py CHANGED
@@ -627,14 +627,14 @@ class RunnerBase:
627
  cached_file = download_cached_file(
628
  url_or_filename, check_hash=False, progress=True
629
  )
630
- checkpoint = torch.load(cached_file, map_location=self.device, strict=False)
631
  elif os.path.isfile(url_or_filename):
632
- checkpoint = torch.load(url_or_filename, map_location=self.device, strict=False)
633
  else:
634
  raise RuntimeError("checkpoint url or path is invalid")
635
 
636
  state_dict = checkpoint["model"]
637
- self.unwrap_dist_model(self.model).load_state_dict(state_dict)
638
 
639
  self.optimizer.load_state_dict(checkpoint["optimizer"])
640
  if self.scaler and "scaler" in checkpoint:
 
627
  cached_file = download_cached_file(
628
  url_or_filename, check_hash=False, progress=True
629
  )
630
+ checkpoint = torch.load(cached_file, map_location=self.device)
631
  elif os.path.isfile(url_or_filename):
632
+ checkpoint = torch.load(url_or_filename, map_location=self.device)
633
  else:
634
  raise RuntimeError("checkpoint url or path is invalid")
635
 
636
  state_dict = checkpoint["model"]
637
+ self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)
638
 
639
  self.optimizer.load_state_dict(checkpoint["optimizer"])
640
  if self.scaler and "scaler" in checkpoint: