#!/usr/bin/env python
"""
Specialized script for quantizing Qwen2.5-VL models with sequential onloading
Handles quantization of Qwen2_5_VLForConditionalGeneration models properly
"""
import base64
from io import BytesIO
from typing import Optional
import torch
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping
from llmcompressor.utils import dispatch_for_generation


def create_qwen2_5_vl_data_collator():
"""Create a data collator for Qwen2.5-VL models that handles multimodal inputs."""
def data_collator(batch):
assert len(batch) == 1
return {key: torch.tensor(value) if isinstance(value, (list, int, float)) else value
for key, value in batch[0].items()}
return data_collator


def create_qwen2_5_vl_preprocessing_fn(processor, max_sequence_length: int = 2048):
"""Create a preprocessing function for Qwen2.5-VL datasets."""
def preprocess_and_tokenize(example):
# Handle different image formats
if 'image' in example:
# Process image
if hasattr(example['image'], 'save'):
# PIL Image object
buffered = BytesIO()
example["image"].save(buffered, format="PNG")
encoded_image = base64.b64encode(buffered.getvalue())
encoded_image_text = encoded_image.decode("utf-8")
base64_qwen = f"data:image;base64,{encoded_image_text}"
else:
# Already a string or other format
base64_qwen = str(example["image"])
else:
# If there's no image field, try 'img' or similar
img_key = None
for key in example.keys():
if 'image' in key.lower() or 'img' in key.lower():
img_key = key
break
if img_key:
if hasattr(example[img_key], 'save'):
buffered = BytesIO()
example[img_key].save(buffered, format="PNG")
encoded_image = base64.b64encode(buffered.getvalue())
encoded_image_text = encoded_image.decode("utf-8")
base64_qwen = f"data:image;base64,{encoded_image_text}"
else:
base64_qwen = str(example[img_key])
else:
# If no image, create a simple text-only example
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": example.get('text', example.get('content', 'What can you tell me about this?'))},
],
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return processor(
text=[text],
padding=False,
max_length=max_sequence_length,
truncation=True,
)
# Create message with image
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": base64_qwen},
{"type": "text", "text": "What does the image show?"},
],
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
# tokenize
return processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=False,
max_length=max_sequence_length,
truncation=True,
)
return preprocess_and_tokenize


def get_qwen2_5_vl_quantization_recipe(method: str, scheme: str = "W4A16"):
"""
Creates the appropriate quantization recipe for Qwen2.5-VL models.
Args:
method: Quantization method ("GPTQ", "AWQ", or "FP8")
scheme: Quantization scheme (e.g., "W4A16", "W8A8", "FP8")
Returns:
List of modifiers for the quantization recipe
"""
if method == "GPTQ":
return [
GPTQModifier(
targets="Linear",
scheme=scheme,
ignore=["lm_head", "re:visual.*", "re:model.visual.*"],
sequential_targets=["Qwen2_5_VLDecoderLayer"], # This is key for the architecture
),
]
elif method == "AWQ":
# Create AWQ mappings for Qwen2.5-VL architecture
mappings = [
AWQMapping(
"re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
),
AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
AWQMapping(
"re:.*post_attention_layernorm", ["re:.*gate_proj", "re:.*up_proj"]
),
AWQMapping("re:.*up_proj", ["re:.*down_proj"]),
]
return [
AWQModifier(
ignore=["lm_head", "re:visual.*", "re:model.visual.*"],
scheme="W4A16_ASYM" if scheme == "W4A16" else scheme,
targets=["Linear"],
mappings=mappings,
),
]
elif method == "FP8":
return [
QuantizationModifier(
scheme="FP8",
targets="Linear",
ignore=["lm_head", "re:visual.*", "re:model.visual.*"]
)
]
else:
raise ValueError(f"Unsupported quantization method: {method}")
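# Example (illustrative): build a GPTQ W4A16 recipe that skips the vision tower.
#   recipe = get_qwen2_5_vl_quantization_recipe("GPTQ", "W4A16")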


def quantize_qwen2_5_vl_model(
model_id: str,
quantization_method: str,
output_dir: Optional[str] = None,
dataset_id: str = "lmms-lab/flickr30k",
dataset_split: str = "test[:512]",
num_calibration_samples: int = 512,
max_sequence_length: int = 2048,
scheme: str = "W4A16",
trust_remote_code: bool = True,
):
"""
Quantizes a Qwen2.5-VL model with proper architecture handling and sequential onloading.
Args:
model_id: Hugging Face model ID to quantize
quantization_method: Method to use ("GPTQ", "AWQ", or "FP8")
output_dir: Directory to save the quantized model
dataset_id: Dataset ID for calibration
dataset_split: Dataset split for calibration
num_calibration_samples: Number of samples to use for calibration
max_sequence_length: Maximum sequence length for processing
scheme: Quantization scheme (e.g., "W4A16", "W8A8")
trust_remote_code: Whether to trust remote code in model loading
Returns:
Quantized model
"""
print(f"Loading model: {model_id}")
# Handle different device scenarios properly
if torch.cuda.is_available():
try:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16, # Use float16 to save memory on GPU
device_map="auto", # Auto device mapping for memory efficiency
trust_remote_code=trust_remote_code
)
except RuntimeError as e:
if "out of memory" in str(e).lower() or "offload_dir" in str(e):
print(f"Memory issue detected, using offloading: {e}")
                import tempfile
                # Use a persistent directory for offloading: a TemporaryDirectory
                # context manager would delete the folder on exit while the model
                # still reads offloaded weights from it during calibration.
                offload_dir = tempfile.mkdtemp(prefix="qwen2_5_vl_offload_")
                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    offload_folder=offload_dir,
                    max_memory={0: "24GB", "cpu": "48GB"},
                    trust_remote_code=trust_remote_code
                )
else:
raise
else:
# CPU only
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float32, # Use float32 on CPU
device_map="cpu",
trust_remote_code=trust_remote_code
)
print(f"Loading processor for: {model_id}")
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=trust_remote_code)
# If output directory not specified, create one based on model and method
if not output_dir:
model_name = model_id.rstrip("/").split("/")[-1]
output_dir = f"{model_name}-{scheme.replace(':', '-')}-{quantization_method}"
print(f"Output directory: {output_dir}")
# Load dataset and preprocess
print(f"Loading dataset: {dataset_id}")
try:
ds = load_dataset(dataset_id, split=dataset_split)
except Exception as e:
print(f"Failed to load {dataset_id}, trying alternative text-only dataset: {e}")
# If the image dataset fails, try a text-only dataset
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:512]")
# We'll need to adjust preprocessing for text-only data
ds = ds.shuffle(seed=42)
# Apply preprocessing
preprocess_fn = create_qwen2_5_vl_preprocessing_fn(processor, max_sequence_length)
try:
ds = ds.map(preprocess_fn, remove_columns=ds.column_names if hasattr(ds, 'column_names') else [])
except Exception as e:
print(f"Preprocessing failed: {e}")
print("Trying simpler preprocessing with text-only data...")
# Fallback: use text-only preprocessing
def text_only_preprocess(example):
text = example.get('text', example.get('content', str(example)))
if not isinstance(text, str):
text = str(text)
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": text[:500] + "..." if len(text) > 500 else text}, # Limit length
],
}
]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
return processor(text=[prompt], padding=False, max_length=max_sequence_length, truncation=True)
ds = ds.map(text_only_preprocess, remove_columns=ds.column_names if hasattr(ds, 'column_names') else [])
# Define data collator
data_collator = create_qwen2_5_vl_data_collator()
# Create recipe
recipe = get_qwen2_5_vl_quantization_recipe(quantization_method, scheme)
print(f"Starting quantization with method: {quantization_method}")
print(f"Using recipe: {recipe}")
# Perform oneshot quantization with sequential targets and proper handling
oneshot(
model=model,
tokenizer=processor, # Use processor as tokenizer for Qwen2.5-VL
dataset=ds,
recipe=recipe,
max_seq_length=max_sequence_length,
num_calibration_samples=num_calibration_samples,
trust_remote_code_model=trust_remote_code,
data_collator=data_collator,
# Use sequential onloading for memory efficiency
sequential_targets=["Qwen2_5_VLDecoderLayer"],
save_compressed=True,
output_dir=output_dir,
)
print(f"Quantization completed! Model saved to: {output_dir}")
# Save the processor as well
processor.save_pretrained(output_dir)
return model
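# Example programmatic use (a minimal sketch; the model ID and values are illustrative):
#   model = quantize_qwen2_5_vl_model(
#       model_id="Qwen/Qwen2.5-VL-7B-Instruct",
#       quantization_method="GPTQ",
#       scheme="W4A16",
#       num_calibration_samples=128,
#   )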


def test_quantized_model(model, processor, max_sequence_length: int = 2048):
"""
Tests the quantized model with a sample generation.
"""
print("========== SAMPLE GENERATION ==============")
try:
dispatch_for_generation(model)
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "http://images.cocodataset.org/train2017/000000231895.jpg",
},
{"type": "text", "text": "Please describe the animal in this image\n"},
],
}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[prompt],
images=image_inputs,
videos=video_inputs,
padding=False,
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
).to(model.device)
output = model.generate(**inputs, max_new_tokens=100)
result = processor.decode(output[0], skip_special_tokens=True)
print(result)
print("==========================================")
return result
except Exception as e:
print(f"Test generation failed: {e}")
print("Trying text-only generation...")
# Try with text-only
try:
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello, how are you today?"},
],
}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
text=[prompt],
padding=False,
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
).to(model.device)
output = model.generate(**inputs, max_new_tokens=50)
result = processor.decode(output[0], skip_special_tokens=True)
print(result)
print("==========================================")
return result
except Exception as e2:
print(f"Text-only generation also failed: {e2}")
return None


def main():
"""
Main function to demonstrate quantization of Qwen2.5-VL models.
"""
import argparse
parser = argparse.ArgumentParser(description="Quantize Qwen2.5-VL models")
parser.add_argument("--model_id", type=str, required=True,
help="Model ID to quantize (e.g., 'huihui-ai/Huihui-Fara-7B-abliterated')")
parser.add_argument("--method", type=str, choices=["GPTQ", "AWQ", "FP8"],
default="GPTQ", help="Quantization method to use")
parser.add_argument("--output_dir", type=str, default=None,
help="Output directory for quantized model")
parser.add_argument("--dataset_id", type=str, default="lmms-lab/flickr30k",
help="Dataset for calibration (default: lmms-lab/flickr30k)")
parser.add_argument("--scheme", type=str, default="W4A16",
help="Quantization scheme (e.g., W4A16, W8A8)")
parser.add_argument("--num_samples", type=int, default=128,
help="Number of calibration samples")
args = parser.parse_args()
print(f"Starting quantization of {args.model_id} using {args.method}")
try:
# Quantize the model
quantized_model = quantize_qwen2_5_vl_model(
model_id=args.model_id,
quantization_method=args.method,
output_dir=args.output_dir,
dataset_id=args.dataset_id,
num_calibration_samples=args.num_samples,
scheme=args.scheme
)
# Test the model
# Load the processor again to test
processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
test_quantized_model(quantized_model, processor)
print(f"Successfully quantized {args.model_id} with {args.method}")
except Exception as e:
print(f"Quantization failed: {e}")
import traceback
traceback.print_exc()


if __name__ == "__main__":
main()