matthewlyleolson committed 3371fdb (1 parent: a678b33)

Update README.md

Files changed (1): README.md (+138 -5)

---
license: other
license_name: intel-research-use-license
license_link: LICENSE
---

# LLaVA-Llama3 Model Card

_This model card corresponds to the instruction-tuned 8B version of the model with the CLIP-based vision encoder._

## Overview

`llava-llama-3-8b` is a large multimodal model (LMM) trained with the [LLaVA-v1.5 framework](https://arxiv.org/abs/2310.03744), using the 8-billion-parameter [`meta-llama/Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B) model as its language backbone.
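
Under this framework the checkpoint combines a vision encoder, a projector, and the language model. The sketch below is illustrative only: it assumes the released checkpoint loads as a standard `transformers` LLaVA model (as it does in the Usage section below) and simply inspects those sub-modules.

```python
# Illustrative sketch: inspect the sub-modules of the released checkpoint.
# Note: the checkpoint holds weight *differences*; merging the base Llama-3
# weights (see Usage below) is still required before running inference.
from transformers import AutoModelForPreTraining

model = AutoModelForPreTraining.from_pretrained("Intel/llava-llama-3-8b-old")
print(type(model.vision_tower).__name__)           # CLIP-based vision encoder
print(type(model.multi_modal_projector).__name__)  # projector between vision and language
print(type(model.language_model).__name__)         # Llama-3 language backbone
```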

## Uses

The model has been finetuned for multimodal benchmark evaluations, but can also be used as a multimodal chatbot.

## Bias, Risks, and Limitations

This model has not been assessed for harm or biases, and should not be used for sensitive applications where it may cause harm.

## Training Details

The `llava-llama-3-8b` model was trained on a 4-node cluster with a total of 32 Gaudi 2 accelerators.

### Training Data

The model was trained on the LLaVA-v1.5 data mixture, which consists of the following (a sketch of the record format appears after the list):

- 558K filtered image-text pairs from LAION/CC/SBU, captioned by BLIP.
- 158K GPT-generated multimodal instruction-following samples.
- 450K academic-task-oriented VQA samples.
- 40K ShareGPT samples.
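
For reference, each sample in the public LLaVA-v1.5 mixture pairs an image path with a multi-turn conversation. The record below is illustrative only: the id, path, and text are made up, and the field names follow the published mixture files.

```python
# Illustrative only: the shape of one LLaVA-v1.5 instruction-tuning record.
# The "<image>" token marks where the image features are inserted in the prompt.
sample = {
    "id": "000000033471",                        # made-up sample id
    "image": "coco/train2017/000000033471.jpg",  # made-up image path
    "conversations": [
        {"from": "human", "value": "<image>\nWhat color is the bus?"},
        {"from": "gpt", "value": "The bus is white and red."},
    ],
}
```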

## Evaluation

| Benchmark | Score |
|-----------|-------|
| ScienceQA | 72.9797 |
| MM-Vet | 31.9725 |
| LLaVA-Bench (in-the-wild) | 56.9 / 61.9 / 73.6 / 65.7 |
| POPE | Acc 87.33, F1 86.5 |
| GQA | 60.6138 |
| MMVP | 36 |

## License

The weights are released under the Intel Research Use License Agreement (see the LICENSE file). All usage code is licensed under Apache 2.0.

## Usage

Please note that we provide only the trained weight difference and do not provide a copy of the base [`meta-llama/Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B) model; any use of these weights requires a separate download of the base model. The snippet below downloads the base model, adds its weights onto the released difference, and then runs generation on an example image.

```python
# Copyright 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForPreTraining
import transformers


def expand2square(pil_img, background_color):
    """Pad an image to a square canvas, centering it on the given background color."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def add_model_a_to_b(model_a, model_b):
    """Add model_a's weights onto model_b's weights, key by key."""
    state_dict_a = model_a.state_dict()
    state_dict_b = model_b.state_dict()

    # Ensure keys match before adding
    if set(state_dict_a.keys()) != set(state_dict_b.keys()):
        raise ValueError("Model state dicts do not have the same keys.")

    for key in state_dict_a:
        if state_dict_a[key].shape != state_dict_b[key].shape:
            raise ValueError(f"Shape mismatch for key '{key}': {state_dict_a[key].shape} vs {state_dict_b[key].shape}")
        # Add model_a's weights to model_b's weights for the matching key
        state_dict_b[key] = state_dict_b[key] + state_dict_a[key]

    # Update model_b with the merged weights
    model_b.load_state_dict(state_dict_b)


output_checkpoint = ""  # set to a path if you don't want to merge every time
hf_checkpoint = "Intel/llava-llama-3-8b-old"
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(hf_checkpoint)
model = AutoModelForPreTraining.from_pretrained(hf_checkpoint)

# The released checkpoint only holds the weight difference; if the embedding row
# is still all zeros, the base Llama-3 weights have not been added yet.
if model.language_model.model.embed_tokens.weight[-1].sum() == 0:
    print("adding llama3 weights")
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="cpu",
    )
    llama3 = pipeline.model
    add_model_a_to_b(llama3, model.language_model)
    if output_checkpoint:
        print("saving merged weights, so no adding is needed again")
        model.save_pretrained(output_checkpoint)

model.to(device)

prompt = processor.tokenizer.apply_chat_template(
    [{"role": "user", "content": "<image>\nWhat's the content of the image?"}],
    tokenize=False,
    add_generation_prompt=True,
)

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# The original LLaVA pads with the image mean, while HF LLaVA pads with zeros
image = expand2square(image, tuple(int(x * 255) for x in processor.image_processor.image_mean))

inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

# Generate
generate_ids = model.generate(**inputs, max_length=30)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output)
```
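
If `output_checkpoint` was set above, later sessions can load the merged weights directly and skip the add step. Below is a minimal sketch, assuming the merged model was saved locally (the path is hypothetical); prompt construction, image padding, and generation then proceed exactly as in the snippet above.

```python
# Minimal sketch: reuse a previously merged checkpoint (path is hypothetical).
import torch
from transformers import AutoProcessor, AutoModelForPreTraining

merged_checkpoint = "./llava-llama-3-8b-merged"  # wherever output_checkpoint pointed
device = "cuda" if torch.cuda.is_available() else "cpu"

# The processor was not re-saved above, so load it from the released checkpoint.
processor = AutoProcessor.from_pretrained("Intel/llava-llama-3-8b-old")
model = AutoModelForPreTraining.from_pretrained(merged_checkpoint).to(device)
```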