---
language:
  - en
license: apache-2.0
tags:
  - llava
  - vlm
---

The English Baichuan2-7B-Chat vision-language model (VLM), trained via LoRA for the paper "See It from My Perspective: Diagnosing the Western Cultural Bias of Large Vision-Language Models in Image Understanding".
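For reference, LoRA (low-rank adaptation) leaves a pretrained weight matrix $W_0 \in \mathbb{R}^{d \times k}$ frozen and learns only a low-rank update. The standard formulation is shown below; the rank $r$, scaling $\alpha$, and target modules actually used for this checkpoint are set in the fine-tuning script and are not restated in this card.

```latex
h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x,
\qquad B \in \mathbb{R}^{d \times r},\quad A \in \mathbb{R}^{r \times k},\quad r \ll \min(d, k)
```

Only $A$ and $B$ are trained; $B$ is initialized to zero, so $\Delta W = 0$ and the model starts out identical to the base LLM.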

Vision Encoder: CLIP-L

Base LLM: Baichuan2-7B-Chat
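Each image reaches the LLM as a grid of CLIP patch embeddings that a projector maps into the LLM's embedding space. The sketch below counts how many sequence positions that grid occupies. It assumes a ViT-L/14 encoder at 336x336 input with the CLS token dropped (the upstream LLaVA-1.5 defaults); this card does not state the resolution or patch size, so check the checkpoint config before relying on the numbers. The function names are hypothetical.

```python
# Hypothetical helpers: count the positions one image occupies in the LLM input.
# ASSUMPTION: CLIP ViT-L/14 at 336x336 with patch features only (no CLS),
# as in upstream LLaVA-1.5; verify against this checkpoint's config.


def num_visual_tokens(image_size: int, patch_size: int) -> int:
    """Number of patch embeddings produced for a square ViT input."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side


def llm_sequence_length(num_prompt_ids: int, num_images: int,
                        image_size: int = 336, patch_size: int = 14) -> int:
    """Sequence length after each image placeholder id is expanded.

    num_prompt_ids counts the tokenized prompt, including one placeholder
    id per image; each placeholder is replaced by the full patch grid.
    """
    per_image = num_visual_tokens(image_size, patch_size)
    return num_prompt_ids + num_images * (per_image - 1)


print(num_visual_tokens(336, 14))   # 576 (a 24 x 24 grid)
print(llm_sequence_length(20, 1))   # 595
```

Under these assumptions a single image costs 576 positions of the LLM's context, which dominates the length of a short prompt.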

Training Corpus:

  • alignment: the corpus used by LLaVA
  • visual instruction tuning: the corpus used by LLaVA

Alignment Script: https://github.com/amith-ananthram/mLLaVA/blob/main/scripts/v1_5/pretrain.sh

Visual Instruction Tuning Script: https://github.com/amith-ananthram/mLLaVA/blob/main/scripts/v1_5/finetune_lora.sh
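The two scripts correspond to two training stages. The sketch below summarizes what each stage updates, assuming mLLaVA keeps the upstream LLaVA-1.5 recipe: projector-only alignment, then LoRA fine-tuning of the LLM with the projector still trainable and the vision encoder frozen. The component labels are hypothetical, not attribute names in the mLLaVA code; verify against the linked scripts.

```python
# Hypothetical summary of the two training stages.
# ASSUMPTION: mLLaVA follows upstream LLaVA-1.5 (pretrain.sh, finetune_lora.sh).
from typing import Dict, List

STAGES: Dict[str, Dict[str, List[str]]] = {
    # pretrain.sh: learn only the vision-to-language projector
    "alignment": {
        "trainable": ["projector"],
        "frozen": ["vision_encoder", "llm"],
    },
    # finetune_lora.sh: keep training the projector and add LoRA adapters to the LLM
    "visual_instruction_tuning": {
        "trainable": ["projector", "llm_lora_adapters"],
        "frozen": ["vision_encoder", "llm_base_weights"],
    },
}


def is_trainable(stage: str, component: str) -> bool:
    """True if `component` receives gradient updates during `stage`."""
    return component in STAGES[stage]["trainable"]


for name, spec in STAGES.items():
    print(f"{name}: trains {spec['trainable']}, freezes {spec['frozen']}")
```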

Usage Example:

import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForVisualQuestionAnswering

# helpers from constants.py and utils.py, which are included as files in this HF release
from constants import IMAGE_TOKEN_INDEX
from utils import tokenizer_image_token, process_images

device = torch.device('cuda')

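# load the model (its custom code comes from the repo via trust_remote_code),
# then explicitly load the CLIP vision tower and its image processor
model = AutoModelForVisualQuestionAnswering.from_pretrained('amitha/mllava-baichuan2-en', trust_remote_code=True)
model.model.vision_tower.load_model()
model = model.eval().to(device)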

image_processor = model.get_vision_tower().image_processor
tokenizer = AutoTokenizer.from_pretrained('baichuan-inc/Baichuan2-7B-Chat', trust_remote_code=True)

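# Baichuan2-Chat turn markers: <reserved_106> opens the user turn, <reserved_107> opens the assistant turn
prompt = '<reserved_106><image>\nPlease describe this image.<reserved_107>'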

input_ids = tokenizer_image_token(
    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
)
# preprocess the image into a float16 pixel tensor and record its original size
with Image.open("path/to/image.png") as img:
    images = process_images(
        [img.convert('RGB')], image_processor, model.config
    ).to(dtype=torch.float16)
    image_sizes = [img.size]

# add a batch dimension and generate; the attention mask covers every prompt position
with torch.no_grad():
    output = model.generate(
        inputs=input_ids.unsqueeze(dim=0).to(device),
        attention_mask=torch.ones_like(input_ids).unsqueeze(dim=0).to(device),
        images=images.to(device),
        image_sizes=image_sizes
    )

print(tokenizer.batch_decode(output, skip_special_tokens=True))
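tokenizer_image_token turns the literal <image> string in the prompt into a placeholder id that the model later replaces with projected image features. A minimal pure-Python sketch of that idea follows, assuming the sentinel value -200 used by upstream LLaVA (check IMAGE_TOKEN_INDEX in constants.py). build_prompt, splice_image_token, and make_toy_encoder are hypothetical names, and the toy encoder only stands in for the Baichuan2 tokenizer.

```python
# Hypothetical sketch of how an <image> placeholder becomes a sentinel id.
from typing import Callable, Dict, List

IMAGE_PLACEHOLDER = "<image>"
SENTINEL_ID = -200  # ASSUMPTION: matches IMAGE_TOKEN_INDEX in constants.py


def build_prompt(question: str) -> str:
    """Wrap a question in Baichuan2-Chat turn markers, image first."""
    return f"<reserved_106>{IMAGE_PLACEHOLDER}\n{question}<reserved_107>"


def splice_image_token(prompt: str, encode: Callable[[str], List[int]],
                       image_id: int = SENTINEL_ID) -> List[int]:
    """Tokenize the text around each <image> and insert image_id between chunks."""
    ids: List[int] = []
    for i, chunk in enumerate(prompt.split(IMAGE_PLACEHOLDER)):
        if i > 0:
            ids.append(image_id)  # one sentinel per <image> occurrence
        ids.extend(encode(chunk))
    return ids


def make_toy_encoder() -> Callable[[str], List[int]]:
    """Stand-in tokenizer: assigns a new id to each unseen whitespace-separated word."""
    vocab: Dict[str, int] = {}

    def encode(text: str) -> List[int]:
        return [vocab.setdefault(word, len(vocab)) for word in text.split()]

    return encode


prompt = build_prompt("Please describe this image.")
ids = splice_image_token(prompt, make_toy_encoder())
print(ids)  # [0, -200, 1, 2, 3, 4]
```

With the real tokenizer, encode could be `lambda s: tokenizer(s).input_ids`; the upstream helper additionally handles a leading BOS token and tensor conversion.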