#!/usr/bin/env python
"""
Specialized script for quantizing Qwen2.5-VL models with sequential onloading.
Handles quantization of Qwen2_5_VLForConditionalGeneration models properly.
"""
import base64
from io import BytesIO
from typing import Optional

import torch
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping
from llmcompressor.utils import dispatch_for_generation
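
# Note: process_vision_info comes from the qwen-vl-utils helper package
# (pip install qwen-vl-utils), needed alongside transformers, datasets,
# and llmcompressor.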


def create_qwen2_5_vl_data_collator():
    """Create a data collator for Qwen2.5-VL models that handles multimodal inputs."""

    def data_collator(batch):
        assert len(batch) == 1
        return {
            key: torch.tensor(value) if isinstance(value, (list, int, float)) else value
            for key, value in batch[0].items()
        }

    return data_collator
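
# A minimal sketch of what the collator does (hypothetical field values):
#   batch = [{"input_ids": [[101, 102]], "attention_mask": [[1, 1]]}]
#   data_collator(batch)
#   -> {"input_ids": tensor([[101, 102]]), "attention_mask": tensor([[1, 1]])}
# Calibration runs one sample at a time, hence the assert on batch size 1.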


def create_qwen2_5_vl_preprocessing_fn(processor, max_sequence_length: int = 2048):
    """Create a preprocessing function for Qwen2.5-VL datasets."""

    def preprocess_and_tokenize(example):
        # Handle different image formats
        if "image" in example:
            # Process image
            if hasattr(example["image"], "save"):
                # PIL Image object: serialize to a base64 data URI
                buffered = BytesIO()
                example["image"].save(buffered, format="PNG")
                encoded_image = base64.b64encode(buffered.getvalue())
                encoded_image_text = encoded_image.decode("utf-8")
                base64_qwen = f"data:image;base64,{encoded_image_text}"
            else:
                # Already a string or other format
                base64_qwen = str(example["image"])
        else:
            # If there's no 'image' field, try 'img' or similar
            img_key = None
            for key in example.keys():
                if "image" in key.lower() or "img" in key.lower():
                    img_key = key
                    break
            if img_key:
                if hasattr(example[img_key], "save"):
                    buffered = BytesIO()
                    example[img_key].save(buffered, format="PNG")
                    encoded_image = base64.b64encode(buffered.getvalue())
                    encoded_image_text = encoded_image.decode("utf-8")
                    base64_qwen = f"data:image;base64,{encoded_image_text}"
                else:
                    base64_qwen = str(example[img_key])
            else:
                # No image at all: fall back to a simple text-only example
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": example.get(
                                    "text",
                                    example.get("content", "What can you tell me about this?"),
                                ),
                            },
                        ],
                    }
                ]
                text = processor.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                return processor(
                    text=[text],
                    padding=False,
                    max_length=max_sequence_length,
                    truncation=True,
                )

        # Create message with image
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": base64_qwen},
                    {"type": "text", "text": "What does the image show?"},
                ],
            }
        ]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        # Tokenize the prompt together with the extracted vision inputs
        return processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=False,
            max_length=max_sequence_length,
            truncation=True,
        )

    return preprocess_and_tokenize
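
# Typical usage (a sketch; actual dataset columns vary by source):
#   preprocess_fn = create_qwen2_5_vl_preprocessing_fn(processor, max_sequence_length=2048)
#   ds = ds.map(preprocess_fn, remove_columns=ds.column_names)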


def get_qwen2_5_vl_quantization_recipe(method: str, scheme: str = "W4A16"):
    """
    Creates the appropriate quantization recipe for Qwen2.5-VL models.

    Args:
        method: Quantization method ("GPTQ", "AWQ", or "FP8")
        scheme: Quantization scheme (e.g., "W4A16", "W8A8", "FP8")

    Returns:
        List of modifiers for the quantization recipe
    """
    if method == "GPTQ":
        return [
            GPTQModifier(
                targets="Linear",
                scheme=scheme,
                ignore=["lm_head", "re:visual.*", "re:model.visual.*"],
                sequential_targets=["Qwen2_5_VLDecoderLayer"],  # key for this architecture
            ),
        ]
    elif method == "AWQ":
        # Create AWQ mappings for the Qwen2.5-VL decoder architecture
        mappings = [
            AWQMapping(
                "re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
            ),
            AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
            AWQMapping(
                "re:.*post_attention_layernorm", ["re:.*gate_proj", "re:.*up_proj"]
            ),
            AWQMapping("re:.*up_proj", ["re:.*down_proj"]),
        ]
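        # Each mapping pairs a "smoothing" layer with the linear layers it feeds,
        # so AWQ can migrate activation scale between them before quantizing.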
        return [
            AWQModifier(
                ignore=["lm_head", "re:visual.*", "re:model.visual.*"],
                scheme="W4A16_ASYM" if scheme == "W4A16" else scheme,
                targets=["Linear"],
                mappings=mappings,
            ),
        ]
    elif method == "FP8":
        return [
            QuantizationModifier(
                scheme="FP8",
                targets="Linear",
                ignore=["lm_head", "re:visual.*", "re:model.visual.*"],
            )
        ]
    else:
        raise ValueError(f"Unsupported quantization method: {method}")
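
# For example, get_qwen2_5_vl_quantization_recipe("GPTQ", "W4A16") yields a
# single GPTQModifier that quantizes the language-model Linear layers while
# skipping the vision tower and lm_head.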


def quantize_qwen2_5_vl_model(
    model_id: str,
    quantization_method: str,
    output_dir: Optional[str] = None,
    dataset_id: str = "lmms-lab/flickr30k",
    dataset_split: str = "test[:512]",
    num_calibration_samples: int = 512,
    max_sequence_length: int = 2048,
    scheme: str = "W4A16",
    trust_remote_code: bool = True,
):
    """
    Quantizes a Qwen2.5-VL model with proper architecture handling and sequential onloading.

    Args:
        model_id: Hugging Face model ID to quantize
        quantization_method: Method to use ("GPTQ", "AWQ", or "FP8")
        output_dir: Directory to save the quantized model
        dataset_id: Dataset ID for calibration
        dataset_split: Dataset split for calibration
        num_calibration_samples: Number of samples to use for calibration
        max_sequence_length: Maximum sequence length for processing
        scheme: Quantization scheme (e.g., "W4A16", "W8A8")
        trust_remote_code: Whether to trust remote code when loading the model

    Returns:
        Quantized model
    """
    print(f"Loading model: {model_id}")
    # Handle different device scenarios properly
    if torch.cuda.is_available():
        try:
            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=torch.float16,  # float16 to save GPU memory
                device_map="auto",  # automatic device mapping for memory efficiency
                trust_remote_code=trust_remote_code,
            )
        except RuntimeError as e:
            if "out of memory" in str(e).lower() or "offload_dir" in str(e):
                print(f"Memory issue detected, using offloading: {e}")
                import tempfile

                # The offload folder must outlive the model, so create a plain
                # temporary directory rather than a self-deleting context manager.
                temp_dir = tempfile.mkdtemp()
                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    offload_folder=temp_dir,
                    max_memory={0: "24GB", "cpu": "48GB"},
                    trust_remote_code=trust_remote_code,
                )
            else:
                raise
    else:
        # CPU only
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float32,  # float32 on CPU
            device_map="cpu",
            trust_remote_code=trust_remote_code,
        )
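    # Note: the max_memory split above ({0: "24GB", "cpu": "48GB"}) is an
    # illustrative assumption for a single ~24 GB GPU; adjust it to your hardware.
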
| print(f"Loading processor for: {model_id}") | |
| processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=trust_remote_code) | |
| # If output directory not specified, create one based on model and method | |
| if not output_dir: | |
| model_name = model_id.rstrip("/").split("/")[-1] | |
| output_dir = f"{model_name}-{scheme.replace(':', '-')}-{quantization_method}" | |
| print(f"Output directory: {output_dir}") | |
| # Load dataset and preprocess | |
| print(f"Loading dataset: {dataset_id}") | |
| try: | |
| ds = load_dataset(dataset_id, split=dataset_split) | |
| except Exception as e: | |
| print(f"Failed to load {dataset_id}, trying alternative text-only dataset: {e}") | |
| # If the image dataset fails, try a text-only dataset | |
| ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:512]") | |
| # We'll need to adjust preprocessing for text-only data | |
| ds = ds.shuffle(seed=42) | |
| # Apply preprocessing | |
| preprocess_fn = create_qwen2_5_vl_preprocessing_fn(processor, max_sequence_length) | |
| try: | |
| ds = ds.map(preprocess_fn, remove_columns=ds.column_names if hasattr(ds, 'column_names') else []) | |
| except Exception as e: | |
| print(f"Preprocessing failed: {e}") | |
| print("Trying simpler preprocessing with text-only data...") | |
| # Fallback: use text-only preprocessing | |
| def text_only_preprocess(example): | |
| text = example.get('text', example.get('content', str(example))) | |
| if not isinstance(text, str): | |
| text = str(text) | |
| messages = [ | |
| { | |
| "role": "user", | |
| "content": [ | |
| {"type": "text", "text": text[:500] + "..." if len(text) > 500 else text}, # Limit length | |
| ], | |
| } | |
| ] | |
| prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | |
| return processor(text=[prompt], padding=False, max_length=max_sequence_length, truncation=True) | |
| ds = ds.map(text_only_preprocess, remove_columns=ds.column_names if hasattr(ds, 'column_names') else []) | |
    # Define data collator
    data_collator = create_qwen2_5_vl_data_collator()

    # Create recipe
    recipe = get_qwen2_5_vl_quantization_recipe(quantization_method, scheme)
    print(f"Starting quantization with method: {quantization_method}")
    print(f"Using recipe: {recipe}")

    # Perform oneshot quantization with sequential targets and proper handling
    oneshot(
        model=model,
        tokenizer=processor,  # use the processor as the tokenizer for Qwen2.5-VL
        dataset=ds,
        recipe=recipe,
        max_seq_length=max_sequence_length,
        num_calibration_samples=num_calibration_samples,
        trust_remote_code_model=trust_remote_code,
        data_collator=data_collator,
        # Sequential onloading: calibrate one decoder layer at a time for memory efficiency
        sequential_targets=["Qwen2_5_VLDecoderLayer"],
        save_compressed=True,
        output_dir=output_dir,
    )
    print(f"Quantization completed! Model saved to: {output_dir}")

    # Save the processor as well
    processor.save_pretrained(output_dir)
    return model
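

# Example call (illustrative settings; the model ID is the public 7B Instruct checkpoint):
#   model = quantize_qwen2_5_vl_model(
#       model_id="Qwen/Qwen2.5-VL-7B-Instruct",
#       quantization_method="GPTQ",
#       num_calibration_samples=128,
#   )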


def test_quantized_model(model, processor, max_sequence_length: int = 2048):
    """
    Tests the quantized model with a sample generation.
    """
    print("========== SAMPLE GENERATION ==============")
    try:
        dispatch_for_generation(model)
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": "http://images.cocodataset.org/train2017/000000231895.jpg",
                    },
                    {"type": "text", "text": "Please describe the animal in this image\n"},
                ],
            }
        ]
        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=False,
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        ).to(model.device)
        output = model.generate(**inputs, max_new_tokens=100)
        result = processor.decode(output[0], skip_special_tokens=True)
        print(result)
        print("==========================================")
        return result
    except Exception as e:
        print(f"Test generation failed: {e}")
        print("Trying text-only generation...")
        # Fall back to text-only generation
        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Hello, how are you today?"},
                    ],
                }
            ]
            prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = processor(
                text=[prompt],
                padding=False,
                max_length=max_sequence_length,
                truncation=True,
                return_tensors="pt",
            ).to(model.device)
            output = model.generate(**inputs, max_new_tokens=50)
            result = processor.decode(output[0], skip_special_tokens=True)
            print(result)
            print("==========================================")
            return result
        except Exception as e2:
            print(f"Text-only generation also failed: {e2}")
            return None


def main():
    """
    Main function to demonstrate quantization of Qwen2.5-VL models.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Quantize Qwen2.5-VL models")
    parser.add_argument("--model_id", type=str, required=True,
                        help="Model ID to quantize (e.g., 'huihui-ai/Huihui-Fara-7B-abliterated')")
    parser.add_argument("--method", type=str, choices=["GPTQ", "AWQ", "FP8"],
                        default="GPTQ", help="Quantization method to use")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Output directory for the quantized model")
    parser.add_argument("--dataset_id", type=str, default="lmms-lab/flickr30k",
                        help="Dataset for calibration (default: lmms-lab/flickr30k)")
    parser.add_argument("--scheme", type=str, default="W4A16",
                        help="Quantization scheme (e.g., W4A16, W8A8)")
    parser.add_argument("--num_samples", type=int, default=128,
                        help="Number of calibration samples")
    args = parser.parse_args()

    print(f"Starting quantization of {args.model_id} using {args.method}")
    try:
        # Quantize the model
        quantized_model = quantize_qwen2_5_vl_model(
            model_id=args.model_id,
            quantization_method=args.method,
            output_dir=args.output_dir,
            dataset_id=args.dataset_id,
            num_calibration_samples=args.num_samples,
            scheme=args.scheme,
        )
        # Reload the processor and run a quick generation test
        processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
        test_quantized_model(quantized_model, processor)
        print(f"Successfully quantized {args.model_id} with {args.method}")
    except Exception as e:
        print(f"Quantization failed: {e}")
        import traceback

        traceback.print_exc()


if __name__ == "__main__":
    main()
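
# Example invocation (script name assumed; adjust to your file name):
#   python quantize_qwen2_5_vl.py --model_id Qwen/Qwen2.5-VL-7B-Instruct \
#       --method GPTQ --scheme W4A16 --num_samples 128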