---
library_name: peft
base_model: mistralai/Mistral-7B-Instruct-v0.1
license: apache-2.0
datasets:
- liuhaotian/LLaVA-Instruct-150K
---

# Model Card for mistral-7b-instruct-vision-64-qlora

![Text Meme](meme.jpg)

Is text really all you need? Probably not, but the least we can do is try. This repo contains a QLoRA fine-tune of Mistral-7B on the original LLaVA-Instruct-150K dataset, with each image encoded as a base64 string. With enough data, can an LLM learn to "see" just from text? Early results say absolutely not, but I am committed to burning my GPU credits regardless of how bad the results are.

I do believe that in the future we will see a "simplification" of architectures designed to work across multiple modalities. LLaVA, for example, combines a vision encoder with a pre-trained LLM. Perhaps models of the future will have a joint representation for both images and text, and will not have to rely on splicing two models together. For example, perhaps [Token-Free Models](https://arxiv.org/html/2401.13660v1) could be trained on multi-modal byte representations of inputs. Of course, this would be extremely computationally expensive compared to modern vision models, but maybe 10-20 years down the line it's not that big of a deal?

To use this model, you can load the base Mistral model and the adapter:

```python
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"
ADAPTER_MODEL = "seanmor5/mistral-7b-instruct-vision-64-qlora"
MAX_SEQ_LEN = 2048

device = "cuda"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# load the base model in 4-bit and attach the QLoRA adapter
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, quantization_config=bnb_config, device_map=device
)
model = PeftModel.from_pretrained(model, ADAPTER_MODEL)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, model_max_length=MAX_SEQ_LEN)
tokenizer.pad_token = tokenizer.eos_token
```

One challenge with this approach is sequence length. High-resolution images are large, and when encoded in base64 they create prohibitively long sequences. To naively work around this, we aggressively resize and downsample each image:

```python
import base64
from io import BytesIO
from PIL import Image

TARGET_SIZE = (224, 168)
TARGET_QUALITY = 5

def downsample(path):
    # resize and recompress the image so the base64 string stays small
    img = Image.open(path).convert("RGB")
    img = img.resize(TARGET_SIZE, Image.LANCZOS)
    buf = BytesIO()
    img.save(buf, optimize=True, quality=TARGET_QUALITY, format="JPEG")
    return f"<image>{base64.b64encode(buf.getvalue()).decode()}</image>"
```
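As a rough sanity check (not part of the original card), you can tokenize the output of `downsample` to see how much of the 2,048-token budget a single image consumes; the image path here is just a placeholder:

```python
# Rough sanity check: how many tokens does one encoded image consume?
encoded_image = downsample("dog.jpg")  # placeholder path
image_tokens = len(tokenizer(encoded_image)["input_ids"])
print(f"image alone uses {image_tokens} of {MAX_SEQ_LEN} tokens")
```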

Then we can use the default Mistral chat template, ensuring our images are encoded properly within the text:

```python
def replace_image(seq, img):
    return seq.replace("<image>", downsample(img))

prompt = (
    "<image>\nWhat is the dog doing in this photo?"
)
prompt = replace_image(prompt, "dog.jpg")
print(prompt)

messages = [{"role": "user", "content": prompt}]

encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")

# the model was already placed on the GPU via device_map, so only move the inputs
model_inputs = encodeds.to(device)

generated_ids = model.generate(
    input_ids=model_inputs, max_new_tokens=1000, do_sample=True
)
decoded = tokenizer.batch_decode(generated_ids)
print(decoded[0])
```

Even with this aggressive downsampling, some images result in sequences that are too long. Tough luck. I also only ran this experiment with JPEG images, and did not consider the effect the image format may have had on the model's performance.
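If you would rather catch these cases than hit them at generation time, here is a minimal sketch (not from the original card; the 256-token reserve for the reply is an arbitrary assumption) that checks the tokenized prompt against the model's budget and skips anything that does not fit:

```python
def fits_in_context(prompt, reserve=256):
    # reserve is an assumed buffer for the generated reply, not a value from the card
    n_tokens = len(tokenizer(prompt)["input_ids"])
    return n_tokens + reserve <= MAX_SEQ_LEN

if not fits_in_context(prompt):
    print("Prompt is too long even after downsampling; skip it or resize harder.")
```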

## Model Details

- **Developed by:** Sean Moriarity
- **License:** Apache 2.0