---
license: other
license_name: intel-research-use-license
license_link: LICENSE
---

# LLaVA-Llama3 Model Card

_This model card corresponds to the instruction-tuned 8B version of the model with the CLIP-based vision encoder._

## Overview

`llava-llama-3-8b` is a large multimodal model (LMM) trained with the [LLaVA-v1.5 framework](https://arxiv.org/abs/2310.03744), using the 8-billion-parameter [`meta-llama/Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B) model as its language backbone.

## Uses

The model has been fine-tuned for multimodal benchmark evaluations, but it can also be used as a multimodal chatbot.

## Bias, Risks, and Limitations

This model has not been assessed for harm or biases, and should not be used for sensitive applications where it may cause harm.

## Training Details

The `llava-llama-3-8b` model was trained on a 4-node cluster with a total of 32 Gaudi 2 accelerators.

### Training Data

The model was trained on the LLaVA-v1.5 data mixture:

- 558K filtered image-text pairs from LAION/CC/SBU, captioned by BLIP.
- 158K GPT-generated multimodal instruction-following samples.
- 450K academic-task-oriented VQA samples.
- 40K ShareGPT samples.

## Evaluation

| Benchmark | Score |
|-----------|-------|
| ScienceQA | 72.9797 |
| MMVet | 31.9725 |
| llavaw | 56.9 / 61.9 / 73.6 / 65.7 |
| POPE | Acc 87.33, F1 86.5 |
| GQA | 60.6138 |
| MMVP | 36 |

## License

The weights are released under the Intel Research Use License Agreement (see the LICENSE file). All usage code is licensed under Apache 2.0.

## Usage

Please note that we only provide the trained weight differences and do not provide a copy of the base [`meta-llama/Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B) model. Any use of these weights requires a separate download of the base model.
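The script below loads the released checkpoint, adds the base Llama-3 weights into the language model if they have not been merged yet, and then runs a single image-question example. Setting `output_checkpoint` to a local path saves the merged model so the merge step only has to run once.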
```python
# Copyright 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForPreTraining
import transformers


def expand2square(pil_img, background_color):
    """Pad an image to a square canvas, keeping the original centered."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def add_model_a_to_b(model_a, model_b):
    """Add model_a's weights into model_b in place (used to merge the base model into the released delta)."""
    state_dict_a = model_a.state_dict()
    state_dict_b = model_b.state_dict()

    # Ensure keys match before adding
    if set(state_dict_a.keys()) != set(state_dict_b.keys()):
        raise ValueError("Model state dicts do not have the same keys.")

    for key in state_dict_a:
        if state_dict_a[key].shape != state_dict_b[key].shape:
            raise ValueError(f"Shape mismatch for key '{key}': {state_dict_a[key].shape} vs {state_dict_b[key].shape}")
        # Add model_a's weights to model_b's for the matching key
        state_dict_b[key] = state_dict_b[key] + state_dict_a[key]

    # Update model_b with the merged weights
    model_b.load_state_dict(state_dict_b)


output_checkpoint = ""  # set if you don't want to merge every time
hf_checkpoint = "Intel/llava-llama-3-8b"
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(hf_checkpoint)
model = AutoModelForPreTraining.from_pretrained(hf_checkpoint)

# The released checkpoint only contains the weight difference; if the base
# weights have not been added yet (embedding row is still all zeros), merge
# in the Llama-3 base model.
if model.language_model.model.embed_tokens.weight[-1].sum() == 0:
    print("adding llama3 weights")
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="cpu",
    )
    llama3 = pipeline.model
    add_model_a_to_b(llama3, model.language_model)
    if output_checkpoint:
        print("saving weights, so no adding is needed again")
        model.save_pretrained(output_checkpoint)

model.to(device)

prompt = processor.tokenizer.apply_chat_template(
    [{'role': 'user', 'content': "<image>\nWhat's the content of the image?"}],
    tokenize=False,
    add_generation_prompt=True
)
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# original llava pads with mean, HF llava pads with zeros
image = expand2square(image, tuple(int(x * 255) for x in processor.image_processor.image_mean))
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

# Generate
generate_ids = model.generate(**inputs, max_length=30)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output)
```
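Once the merged model has been saved via `output_checkpoint`, later runs can skip the merge step and load the merged weights directly. A minimal sketch, assuming the merged checkpoint was saved to the hypothetical local path `./llava-llama-3-8b-merged`:

```python
# A minimal sketch for reloading a previously merged checkpoint.
# "./llava-llama-3-8b-merged" is a hypothetical path; use whatever you set
# output_checkpoint to in the script above.
import torch
from transformers import AutoProcessor, AutoModelForPreTraining

device = "cuda" if torch.cuda.is_available() else "cpu"

# The processor (tokenizer + image processor) still comes from the Intel repo.
processor = AutoProcessor.from_pretrained("Intel/llava-llama-3-8b")

# Load the locally saved, already-merged weights; no base-model download or
# add_model_a_to_b step is needed on this run.
model = AutoModelForPreTraining.from_pretrained("./llava-llama-3-8b-merged")
model.to(device)
```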