Image-Text-to-Text
Japanese
English
mitsu-koh commited on
Commit
7055e81
·
1 Parent(s): ea92cde

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +177 -0
  2. checkpoint.pth +3 -0
  3. customized_mini_gpt4.py +149 -0
  4. rinna.png +0 -0
  5. sample.jpg +0 -0
README.md CHANGED
@@ -1,3 +1,180 @@
1
  ---
2
  license: mit
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
+ datasets:
4
+ - conceptual_12m
5
+ - HuggingFaceM4/COCO
6
+ - visual_genome
7
+ language:
8
+ - ja
9
+ - en
10
  ---
11
+
12
+
13
+ # bilingual-gpt-neox-4b-minigpt4
14
+
15
+ ![rinna-icon](./rinna.png)
16
+
17
+ # Overview
18
+ This repository provides an English-Japanese bilingual multimodal conversational model like MiniGPT-4 by combining GPT-NeoX model of 3.8 billion parameters and BLIP-2.
19
+
20
+ The model is based on [`rinna/bilingual-gpt-neox-4b`](https://huggingface.co/rinna/bilingual-gpt-neox-4b) and [BLIP-2](https://huggingface.co/docs/transformers/main/model_doc/blip-2).
21
+
22
+ * **Model architecture**
23
+
24
+ Similar with [BLIP-2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) and [Vision-CAIR/MiniGPT-4](https://huggingface.co/Vision-CAIR/MiniGPT-4), the model consists of an LLM, vision-encoder with ViT and Q-Former, and linear-layer for connecting the LLM and vision-encoder.
25
+
26
+ [`rinna/bilingual-gpt-neox-4b`](https://huggingface.co/rinna/bilingual-gpt-neox-4b) (A 36-layer, 2816-hidden-size transformer-based language model) is used as the LLM instead of [Vicuna](https://github.com/lm-sys/FastChat), which is used in the original [Vision-CAIR/MiniGPT-4](https://huggingface.co/Vision-CAIR/MiniGPT-4).
27
+
28
+ * **Finetuning**
29
+
30
+ The finetuning data is the subset of the following datasets.
31
+ * English datasets
32
+ * [Conceptual 12M (CC12M)](https://huggingface.co/datasets/conceptual_12m)
33
+ * [COCO 2014](https://huggingface.co/datasets/HuggingFaceM4/COCO)
34
+ * [Visual Genome](https://huggingface.co/datasets/visual_genome)
35
+ * Japanese datasets
36
+ * [STAIR-captions](https://github.com/STAIR-Lab-CIT/STAIR-captions)
37
+ * [Japanese Visual Genome VQA dataset](https://github.com/yahoojapan/ja-vg-vqa)
38
+
39
+ Based on the implementation of [Vision-CAIR/MiniGPT-4](https://huggingface.co/Vision-CAIR/MiniGPT-4), only "first pretraining stage" described in [MiniGPT-4 paper](https://arxiv.org/abs/2304.10592) with the above datasets was conducted, and "second-stage finetuning" proposed in the paper with an aligned image-text dataset created with ChatGPT was NOT conducted.
40
+
41
+ * **Model Series**
42
+
43
+ | Variant | Link |
44
+ | :-- | :--|
45
+ | Bilingual 4B MiniGPT4 | https://huggingface.co/rinna/bilingual-gpt-neox-4b-minigpt4 |
46
+ | Bilingual 4B SFT | https://huggingface.co/rinna/bilingual-gpt-neox-4b-instruction-sft |
47
+ | Bilingual 4B 8K | https://huggingface.co/rinna/bilingual-gpt-neox-4b-8k |
48
+ | Bilingual 4B | https://huggingface.co/rinna/bilingual-gpt-neox-4b |
49
+ | Japanese 3.6B PPO | https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-ppo |
50
+ | Japanese 3.6B SFT-v2 | https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft-v2 |
51
+ | Japanese 3.6B SFT | https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft |
52
+ | Japanese 3.6B | https://huggingface.co/rinna/japanese-gpt-neox-3.6b |
53
+
54
+ * **Authors**
55
+
56
+ [Koh Mitsuda](https://huggingface.co/mitsu-koh), [Tianyu Zhao](https://huggingface.co/tianyuz), and [Kei Sawada](https://huggingface.co/keisawada)
57
+
58
+ ---
59
+
60
+ # I/O Format
61
+ A special format has been adopted to construct inputs.
62
+ * An input prompt is formatted as a conversation between `ユーザー` and `システム`.
63
+ * Each input utterance consists of (1) its speaker (`"ユーザー"` or `"システム"`), (2) a colon (`":"`), (3) a whitespace (`" "`), and (4) utterance text (e.g. `"猫はどんな体勢をしていますか?"`).
64
+ * An utterance including an image is formatted as (1) its speaker (`"ユーザー"`), (2) a colon (`":"`), (3) a whitespace (`" "`), (4) a placeholder of the image (`"<Img><ImageHere></Img>"`), (5) another whitespace (`" "`), (6) utterance text (e.g. `"What can you see?"`).
65
+ * The placeholder (`<ImageHere>`) is automatically replaced with the embedding of an input image in the function `get_context_emb`.
66
+ * The input prompt should be ended with `"システム: "` to acknowledge the model to generate a response.
67
+ * All the utterances in the input prompt should be separated by a newline `\n`.
68
+
69
+ Following is an example to construct input from a conversation.
70
+ ~~~python
71
+ prompt = [
72
+ {
73
+ "speaker": "ユーザー",
74
+ "text": "<Img><ImageHere></Img> What can you see?"
75
+ },
76
+ {
77
+ "speaker": "システム",
78
+ "text": "a cat on a table with a laptop"
79
+ },
80
+ {
81
+ "speaker": "ユーザー",
82
+ "text": "猫はどんな体勢をしていますか?"
83
+ },
84
+ ]
85
+ prompt = [
86
+ f"{uttr['speaker']}: {uttr['text']}"
87
+ for uttr in prompt
88
+ ]
89
+ prompt = "\n".join(prompt)
90
+ prompt = (
91
+ prompt
92
+ + "\n"
93
+ + "システム: "
94
+ )
95
+ print(prompt)
96
+ """
97
+ ユーザー: <Img><ImageHere></Img> What can you see?
98
+ システム: a cat on a table with a laptop
99
+ ユーザー: 猫はどんな体勢をしていますか?
100
+ システム:
101
+ """
102
+ ~~~
103
+
104
+ ---
105
+
106
+ # How to use the model
107
+
108
+ **1. Download dependencies**
109
+
110
+ * BLIP-2 implementation included in MiniGPT-4 is used for inference.
111
+ * `customized_mini_gpt4.py` is a script to replace LLM from LLaMA architecture to GPT-NeoX one.
112
+ * `checkpoint.pth` is a finetuned weight of the linear layer (file size: 177 MB).
113
+
114
+ ```bash
115
+ git clone https://github.com/Vision-CAIR/MiniGPT-4.git
116
+ cd ./MiniGPT-4
117
+ git checkout 22d8888 # latest version as of July 31, 2023.
118
+ wget https://huggingface.co/rinna/bilingual-gpt-neox-4b-minigpt4/resolve/main/customized_mini_gpt4.py
119
+ wget https://huggingface.co/rinna/bilingual-gpt-neox-4b-minigpt4/resolve/main/checkpoint.pth
120
+ ```
121
+
122
+ **2. Inference**
123
+
124
+ Please run this script in `MiniGPT-4` directory.
125
+
126
+ ~~~~python
127
+ import torch
128
+ import requests
129
+ from PIL import Image
130
+ from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor
131
+ from customized_mini_gpt4 import CustomizedMiniGPT4
132
+
133
+ ckpt_path = "./checkpoint.pth"
134
+
135
+ model = CustomizedMiniGPT4(gpt_neox_model="rinna/bilingual-gpt-neox-4b")
136
+ tokenizer = model.gpt_neox_tokenizer
137
+
138
+ if torch.cuda.is_available():
139
+ model = model.to("cuda")
140
+
141
+ if ckpt_path is not None:
142
+ print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
143
+ ckpt = torch.load(ckpt_path, map_location="cpu")
144
+ model.load_state_dict(ckpt['model'], strict=False)
145
+
146
+ vis_processor = Blip2ImageEvalProcessor()
147
+
148
+ image_url = "https://huggingface.co/rinna/bilingual-gpt-neox-4b-minigpt4-preview/resolve/main/sample.jpg"
149
+ raw_image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')
150
+ image = vis_processor(raw_image).unsqueeze(0).to(model.device)
151
+ image_emb = model.encode_img(image)
152
+
153
+ embs = model.get_context_emb(prompt, [image_emb])
154
+
155
+ output_ids = model.gpt_neox_model.generate(
156
+ inputs_embeds=embs,
157
+ max_new_tokens=512,
158
+ do_sample=True,
159
+ temperature=1.0,
160
+ top_p=0.85,
161
+ pad_token_id=tokenizer.pad_token_id,
162
+ bos_token_id=tokenizer.bos_token_id,
163
+ eos_token_id=tokenizer.eos_token_id
164
+ )
165
+
166
+ output = tokenizer.decode(output_ids.tolist()[0], skip_special_tokens=True)
167
+ print(output)
168
+ """横になっています。"""
169
+ ~~~~
170
+
171
+ ---
172
+
173
+ # Acknowledgement
174
+
175
+ * [Vision-CAIR/MiniGPT-4](https://huggingface.co/Vision-CAIR/MiniGPT-4)
176
+ * [BLIP-2](https://huggingface.co/docs/transformers/main/model_doc/blip-2)
177
+ * [Lavis](https://github.com/salesforce/LAVIS)
178
+
179
+ # Licenese
180
+ [The MIT license](https://opensource.org/licenses/MIT)
checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:170633d5fd203d8b5b4d6d5ca3e3ce5bc8bb6cf66671ee96c0a6f4a1e38197e6
3
+ size 177115114
customized_mini_gpt4.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from minigpt4.models.mini_gpt4 import MiniGPT4
5
+ from minigpt4.models.blip2 import Blip2Base, disabled_train
6
+
7
+ from transformers.models.gpt_neox import GPTNeoXForCausalLM
8
+ from transformers import AutoTokenizer
9
+
10
+
11
+ class CustomizedGPTNeoXForCausalLM(GPTNeoXForCausalLM):
12
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
13
+ input_shape = input_ids.shape
14
+
15
+ # cut decoder_input_ids if past is used
16
+ if past_key_values and past_key_values[0] is not None:
17
+ input_ids = input_ids[:, -1:]
18
+
19
+ position_ids = kwargs.get("position_ids", None)
20
+ if attention_mask is not None and position_ids is None:
21
+ # create position_ids on the fly for batch generation
22
+ position_ids = attention_mask.long().cumsum(-1) - 1
23
+ position_ids.masked_fill_(attention_mask == 0, 1)
24
+ if past_key_values:
25
+ position_ids = position_ids[:, -1].unsqueeze(-1)
26
+
27
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
28
+ if attention_mask is None:
29
+ attention_mask = input_ids.new_ones(input_shape)
30
+
31
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
32
+ if inputs_embeds is not None and past_key_values is None:
33
+ model_inputs = {"inputs_embeds": inputs_embeds}
34
+ else:
35
+ model_inputs = {"input_ids": input_ids}
36
+
37
+ model_inputs.update(
38
+ {
39
+ "attention_mask": attention_mask,
40
+ "position_ids": position_ids,
41
+ "past_key_values": past_key_values,
42
+ }
43
+ )
44
+ return model_inputs
45
+
46
+
47
+ class CustomizedMiniGPT4(Blip2Base):
48
+ """
49
+ BLIP2 GPT-NeoX model.
50
+ """
51
+ def __init__(
52
+ self,
53
+ gpt_neox_model="rinna/bilingual-gpt-neox-4b",
54
+ vit_model="eva_clip_g",
55
+ q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
56
+ img_size=224,
57
+ drop_path_rate=0,
58
+ use_grad_checkpoint=False,
59
+ vit_precision="fp16",
60
+ freeze_vit=True,
61
+ freeze_qformer=True,
62
+ num_query_token=32,
63
+ low_resource=False, # use 8 bit and put vit in cpu
64
+ device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
65
+ ):
66
+ super().__init__()
67
+
68
+ self.tokenizer = self.init_tokenizer()
69
+ self.low_resource = low_resource
70
+
71
+ print('Loading VIT', flush=True)
72
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
73
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
74
+ )
75
+ if freeze_vit:
76
+ for name, param in self.visual_encoder.named_parameters():
77
+ param.requires_grad = False
78
+ self.visual_encoder = self.visual_encoder.eval()
79
+ self.visual_encoder.train = disabled_train
80
+ for name, param in self.ln_vision.named_parameters():
81
+ param.requires_grad = False
82
+ self.ln_vision = self.ln_vision.eval()
83
+ self.ln_vision.train = disabled_train
84
+ print("freeze vision encoder")
85
+ print('Loading VIT Done')
86
+
87
+ print('Loading Q-Former', flush=True)
88
+ self.Qformer, self.query_tokens = self.init_Qformer(
89
+ num_query_token, self.visual_encoder.num_features
90
+ )
91
+ self.Qformer.cls = None
92
+ self.Qformer.bert.embeddings.word_embeddings = None
93
+ self.Qformer.bert.embeddings.position_embeddings = None
94
+ for layer in self.Qformer.bert.encoder.layer:
95
+ layer.output = None
96
+ layer.intermediate = None
97
+ self.load_from_pretrained(url_or_filename=q_former_model)
98
+
99
+ if freeze_qformer:
100
+ for name, param in self.Qformer.named_parameters():
101
+ param.requires_grad = False
102
+ self.Qformer = self.Qformer.eval()
103
+ self.Qformer.train = disabled_train
104
+ self.query_tokens.requires_grad = False
105
+ print("freeze Qformer")
106
+ print('Loading Q-Former Done')
107
+
108
+ print('Loading LLM', flush=True)
109
+ self.gpt_neox_tokenizer = AutoTokenizer.from_pretrained(gpt_neox_model, use_fast=False)
110
+
111
+ if self.low_resource:
112
+ self.gpt_neox_model = CustomizedGPTNeoXForCausalLM.from_pretrained(
113
+ gpt_neox_model,
114
+ torch_dtype=torch.float16,
115
+ load_in_8bit=True,
116
+ device_map={'': device_8bit}
117
+ )
118
+ else:
119
+ self.gpt_neox_model = CustomizedGPTNeoXForCausalLM.from_pretrained(
120
+ gpt_neox_model,
121
+ torch_dtype=torch.float16,
122
+ )
123
+
124
+ for name, param in self.gpt_neox_model.named_parameters():
125
+ param.requires_grad = False
126
+ print('Loading LLM Done')
127
+
128
+ self.llama_proj = nn.Linear(
129
+ self.Qformer.config.hidden_size, self.gpt_neox_model.config.hidden_size
130
+ )
131
+
132
+ def vit_to_cpu(self):
133
+ MiniGPT4.vit_to_cpu(self)
134
+
135
+ def encode_img(self, image):
136
+ inputs_gpt_neox, _ = MiniGPT4.encode_img(self, image)
137
+ return inputs_gpt_neox
138
+
139
+ def get_context_emb(self, prompt, img_list):
140
+ prompt_segs = prompt.split('<ImageHere>')
141
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
142
+ seg_tokens = [
143
+ self.gpt_neox_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(self.device).input_ids
144
+ for i, seg in enumerate(prompt_segs)
145
+ ]
146
+ seg_embs = [self.gpt_neox_model.gpt_neox.embed_in(seg_t) for seg_t in seg_tokens]
147
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
148
+ mixed_embs = torch.cat(mixed_embs, dim=1)
149
+ return mixed_embs
rinna.png ADDED
sample.jpg ADDED