Upload 3 files

- scripts/dog.jpg (+0 -0)
- scripts/joycaption_vqatest1_4bit.py (+90 -0)
- scripts/requirements.txt (+9 -0)
scripts/dog.jpg
ADDED
scripts/joycaption_vqatest1_4bit.py
ADDED
```python
# from:
# https://gist.github.com/maedtb/ee16101ca80638011c975ed0c0d78497
# https://github.com/fpgaminer/joycaption/issues/3#issuecomment-2619253277
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

IMAGE_PATH = "dog.jpg"
PROMPT = "Write a long descriptive caption for this image in a formal tone."
MODEL_NAME = "John6666/llama-joycaption-alpha-two-vqa-test-1-nf4"
IS_4BIT_MODE = True
MODEL_NATIVE_DTYPE = torch.bfloat16

# Make example output less random
torch.manual_seed(42)

# If 4bit mode is enabled, build our quantization config
kwargs = {}
if IS_4BIT_MODE:
    kwargs['quantization_config'] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=MODEL_NATIVE_DTYPE,
        bnb_4bit_quant_storage=MODEL_NATIVE_DTYPE,
    )
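
# Background: the 'nf4' quant type stores weights as 4-bit NormalFloat values
# (16 levels spaced to match a normal distribution), and double quantization
# also quantizes the block-wise scaling constants, saving roughly 0.4 bits
# per parameter. Compute still runs in bnb_4bit_compute_dtype (bfloat16 here).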

# Load JoyCaption
# bfloat16 is the native dtype of the LLM used in JoyCaption (Llama 3.1)
# device_map=device loads the model onto the first GPU (cuda:0)
device = torch.device('cuda:0')
processor = AutoProcessor.from_pretrained(MODEL_NAME)
llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=MODEL_NATIVE_DTYPE, device_map=device, **kwargs)
llava_model.eval()

# Restore the vision model's out_proj from `nn.Linear4bit` back to `nn.Linear`; it is not dynamically quantizable.
if IS_4BIT_MODE:
    attention = llava_model.vision_tower.vision_model.head.attention
    attention.out_proj = torch.nn.Linear(
        attention.embed_dim,
        attention.embed_dim,
        device=device,
        dtype=MODEL_NATIVE_DTYPE)
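
# Note: the replacement layer is freshly initialized rather than restored from
# the checkpoint; this appears harmless because Llava consumes the vision
# tower's hidden states, not the pooled head output that out_proj feeds.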

with torch.no_grad():
    # Load image
    image = Image.open(IMAGE_PATH)

    # Build the conversation
    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": PROMPT,
        },
    ]

    # Format the conversation
    # WARNING: HF's handling of chats on Llava models is very fragile. This specific combination of
    # processor.apply_chat_template() and processor() works, but if you use other combinations, always
    # inspect the final input_ids to make sure they are correct. If you are not careful you will often
    # end up with multiple <bos> tokens, which can make the model perform poorly.
    convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    assert isinstance(convo_string, str)

    # Process the inputs
    inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to('cuda')
    inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)

    # Generate the captions
    generate_ids = llava_model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
        temperature=0.6,
        top_k=None,
        top_p=0.9,
    )[0]

    # Trim off the prompt
    generate_ids = generate_ids[inputs['input_ids'].shape[1]:]

    # Decode the caption
    caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    caption = caption.strip()
    print(caption)
```
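
The WARNING comment in the script points out that Llava chat templating can silently produce duplicated `<bos>` tokens. A minimal sanity check in that spirit (a sketch reusing the script's `processor` and `inputs`; not part of the uploaded file) could be placed right after the `processor(...)` call:

```python
# Sketch: confirm the templated prompt contains exactly one <bos> token,
# as the WARNING above recommends whenever the templating code changes.
bos_id = processor.tokenizer.bos_token_id
if bos_id is not None:
    n_bos = int((inputs['input_ids'] == bos_id).sum().item())
    assert n_bos == 1, f"expected exactly one <bos> token, found {n_bos}"
```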
scripts/requirements.txt
ADDED
```text
huggingface_hub
accelerate
transformers>=4.46.1
sentencepiece
peft>=0.13.2
numpy<2
torch
torchvision
bitsandbytes
```
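
To reproduce the example, these pinned dependencies would be installed first (for instance with `pip install -r scripts/requirements.txt`), and the script run from inside `scripts/` so that the relative `IMAGE_PATH = "dog.jpg"` resolves; a CUDA GPU is required, since the model is loaded onto `cuda:0`.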