import gradio as gr
from huggingface_hub import HfApi, ModelCard, whoami
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier, GPTQModifier
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping
from transformers import (
    AutoModelForCausalLM,
    Qwen2_5_VLForConditionalGeneration,
    AutoConfig,
    AutoModel,
)
import torch
import time
import threading
from typing import Callable, Optional


# --- Helper Functions ---
class ProgressTracker:
    """Class to track progress and send updates to the UI"""

    def __init__(self):
        self.current_stage = 0
        self.total_stages = 5  # Load model, Get recipe, Run compression, Create repo, Create model card
        self.stage_descriptions = [
            "Loading model and tokenizer...",
            "Preparing quantization recipe...",
            "Running quantization compression...",
            "Creating Hugging Face repository and uploading...",
            "Generating model card...",
        ]
        self.progress = 0.0
        self.status = ""
        self.lock = threading.Lock()

    def update_stage(self, stage_idx: int, description: str = ""):
        with self.lock:
            self.current_stage = stage_idx
            self.status = description or self.stage_descriptions[stage_idx]
            # Calculate progress (each stage is 20% of the total)
            self.progress = min(100.0, (stage_idx / self.total_stages) * 100)

    def update_progress(self, current: float, total: float, description: str = ""):
        with self.lock:
            # Calculate progress within the current stage
            stage_progress = (current / total) * (100.0 / self.total_stages)
            self.progress = min(100.0, ((self.current_stage / self.total_stages) * 100) + stage_progress)
            if description:
                self.status = description

    def get_state(self):
        with self.lock:
            return {
                "progress": self.progress,
                "status": self.status,
                "current_stage": self.current_stage + 1,  # 1-indexed for display
                "total_stages": self.total_stages,
            }
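
# Note: ProgressTracker is not wired into compress_and_upload() below, which reports
# progress directly via gr.Progress. The helper below is a minimal, illustrative sketch
# (not called anywhere in the app) of how the tracker could be driven from a worker
# thread and polled for UI updates.
def _progress_tracker_example():
    """Illustrative only: run a fake job in a thread and poll the tracker."""
    tracker = ProgressTracker()

    def fake_job():
        for stage in range(tracker.total_stages):
            tracker.update_stage(stage)
            time.sleep(0.1)  # stand-in for real work

    worker = threading.Thread(target=fake_job)
    worker.start()
    while worker.is_alive():
        state = tracker.get_state()
        print(f"[{state['current_stage']}/{state['total_stages']}] {state['status']} ({state['progress']:.0f}%)")
        time.sleep(0.1)
    worker.join()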

def get_quantization_recipe(method, model_architecture):
    """
    Returns the appropriate llm-compressor recipe based on the selected method.
    Supports the Qwen2_5_VLForConditionalGeneration architecture in addition to
    standard causal-LM architectures.
    """
    if method == "AWQ":
        if model_architecture not in ["LlamaForCausalLM", "Qwen2_5_VLForConditionalGeneration"]:
            raise ValueError(
                f"AWQ quantization is only supported for LlamaForCausalLM and Qwen2_5_VLForConditionalGeneration architectures, got {model_architecture}"
            )
        # AWQ is incompatible with Qwen2.5-VL models because it conflicts with the
        # 3D rotary positional embedding system used for multimodal processing.
        if model_architecture == "Qwen2_5_VLForConditionalGeneration":
            raise ValueError(
                f"AWQ quantization is not compatible with the {model_architecture} architecture: "
                "its weight modifications conflict with the 3D rotary positional embeddings and "
                "break the multimodal positional encoding system. "
                "Please use GPTQ, W4A16, W8A16, W8A8_INT8, W8A8_FP8, or FP8 instead."
            )
        else:  # LlamaForCausalLM (the only remaining architecture supported for AWQ)
            # AWQ mappings for Llama-style models
            mappings = [
                AWQMapping(
                    "re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
                ),
                AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
                AWQMapping(
                    "re:.*post_attention_layernorm", ["re:.*gate_proj", "re:.*up_proj"]
                ),
                AWQMapping("re:.*up_proj", ["re:.*down_proj"]),
            ]
            return [
                AWQModifier(
                    ignore=["lm_head"],
                    scheme="W4A16_ASYM",
                    targets=["Linear"],
                    mappings=mappings,
                ),
            ]
    elif method == "GPTQ":
        sequential_target_map = {
            "LlamaForCausalLM": "LlamaDecoderLayer",
            "MistralForCausalLM": "MistralDecoderLayer",
            "MixtralForCausalLM": "MixtralDecoderLayer",
            "Qwen2_5_VLForConditionalGeneration": "Qwen2_5_VLDecoderLayer",  # Qwen2.5-VL support
        }
        sequential_target = sequential_target_map.get(model_architecture)
        if sequential_target is None:
            raise ValueError(
                f"GPTQ quantization is not supported for the {model_architecture} architecture. "
                "Supported architectures are: "
                f"{', '.join(sequential_target_map.keys())}"
            )
        if model_architecture == "Qwen2_5_VLForConditionalGeneration":
            return [
                GPTQModifier(
                    targets="Linear",
                    scheme="W4A16",
                    sequential_targets=[sequential_target],
                    ignore=["lm_head", "re:visual.*", "re:model.visual.*"],  # Skip the vision tower
                ),
            ]
        else:
            return [
                GPTQModifier(
                    targets="Linear",
                    scheme="W4A16",
                    sequential_targets=[sequential_target],
                    ignore=["re:.*lm_head"],
                ),
            ]
    elif method in ["W4A16", "W8A16", "W8A8_INT8", "W8A8_FP8", "FP8"]:
        # All of these methods use the QuantizationModifier
        if model_architecture not in ["LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "Qwen2_5_VLForConditionalGeneration"]:
            raise ValueError(
                f"Quantization method {method} is not supported for the {model_architecture} architecture. "
                "Supported architectures are: LlamaForCausalLM, MistralForCausalLM, MixtralForCausalLM, Qwen2_5_VLForConditionalGeneration"
            )
        # Map UI method names to llm-compressor scheme names.
        # Note: both W8A8 variants currently resolve to the same "W8A8" preset;
        # the FP8-activation variant is not distinguished here.
        scheme_map = {
            "W4A16": "W4A16",
            "W8A16": "W8A16",
            "W8A8_INT8": "W8A8",
            "W8A8_FP8": "W8A8",
            "FP8": "FP8",
        }
        ignore_layers = ["lm_head"]
        if "Mixtral" in model_architecture:
            ignore_layers.append("re:.*block_sparse_moe.gate")
        elif "Qwen2_5_VL" in model_architecture:
            ignore_layers.extend(["re:visual.*", "re:model.visual.*"])  # Skip the vision tower for Qwen2.5-VL
        # For W4A16 on Qwen2.5-VL, use GPTQModifier with sequential_targets so decoder layers
        # are onloaded one at a time (better memory behaviour).
        if model_architecture == "Qwen2_5_VLForConditionalGeneration" and method == "W4A16":
            return [
                GPTQModifier(
                    targets="Linear",
                    scheme=scheme_map[method],
                    sequential_targets=["Qwen2_5_VLDecoderLayer"],
                    ignore=ignore_layers,
                ),
            ]
        else:
            return [QuantizationModifier(
                scheme=scheme_map[method],
                targets="Linear",
                ignore=ignore_layers,
            )]
    elif method == "SmoothQuant":
        if model_architecture not in ["LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM"]:
            raise ValueError(
                f"SmoothQuant is not supported for the {model_architecture} architecture. "
                "Supported architectures are: LlamaForCausalLM, MistralForCausalLM, MixtralForCausalLM"
            )
        ignore_layers = ["lm_head"]
        if "Mixtral" in model_architecture:
            ignore_layers.append("re:.*block_sparse_moe.gate")
        # Approximated here with plain W8A8 quantization; no dedicated smoothing pass is applied.
        return [QuantizationModifier(
            scheme="W8A8",
            targets="Linear",
            ignore=ignore_layers,
        )]
    elif method == "SparseGPT":
        if model_architecture not in ["LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM"]:
            raise ValueError(
                f"SparseGPT is not supported for the {model_architecture} architecture. "
                "Supported architectures are: LlamaForCausalLM, MistralForCausalLM, MixtralForCausalLM"
            )
        ignore_layers = ["lm_head"]
        if "Mixtral" in model_architecture:
            ignore_layers.append("re:.*block_sparse_moe.gate")
        # Stand-in: uses the GPTQ modifier with a W4A16 scheme rather than a dedicated sparsity modifier.
        return [
            GPTQModifier(
                targets="Linear",
                scheme="W4A16",
                ignore=ignore_layers,
            )
        ]
    else:
        raise ValueError(f"Unsupported quantization method: {method}")
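
# A minimal, illustrative sketch (not called by the app) of using the recipe helper on its
# own; compress_and_upload() below performs the same step as part of the full pipeline.
def _recipe_example():
    """Illustrative only: build a recipe for a Llama-style model and inspect it."""
    recipe = get_quantization_recipe("GPTQ", "LlamaForCausalLM")
    for modifier in recipe:
        print(type(modifier).__name__)  # e.g. GPTQModifier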

def get_model_class_by_name(model_type_name):
    """
    Returns the appropriate model class based on the user-selected model type name.
    """
    if model_type_name == "CausalLM (standard text generation)":
        return AutoModelForCausalLM
    elif model_type_name == "Qwen2_5_VLForConditionalGeneration (Qwen2.5-VL)":
        from transformers import Qwen2_5_VLForConditionalGeneration
        return Qwen2_5_VLForConditionalGeneration
    elif model_type_name == "Qwen2ForCausalLM (Qwen2)":
        from transformers import Qwen2ForCausalLM
        return Qwen2ForCausalLM
    elif model_type_name == "LlamaForCausalLM (Llama, Llama2, Llama3)":
        from transformers import LlamaForCausalLM
        return LlamaForCausalLM
    elif model_type_name == "MistralForCausalLM (Mistral, Mixtral)":
        from transformers import MistralForCausalLM
        return MistralForCausalLM
    elif model_type_name == "GemmaForCausalLM (Gemma)":
        from transformers import GemmaForCausalLM
        return GemmaForCausalLM
    elif model_type_name == "Gemma2ForCausalLM (Gemma2)":
        from transformers import Gemma2ForCausalLM
        return Gemma2ForCausalLM
    elif model_type_name == "PhiForCausalLM (Phi, Phi2)":
        from transformers import PhiForCausalLM
        return PhiForCausalLM
    elif model_type_name == "Phi3ForCausalLM (Phi3)":
        from transformers import Phi3ForCausalLM
        return Phi3ForCausalLM
    elif model_type_name == "FalconForCausalLM (Falcon)":
        from transformers import FalconForCausalLM
        return FalconForCausalLM
    elif model_type_name == "MptForCausalLM (MPT)":
        from transformers import MptForCausalLM
        return MptForCausalLM
    elif model_type_name == "GPT2LMHeadModel (GPT2)":
        from transformers import GPT2LMHeadModel
        return GPT2LMHeadModel
    elif model_type_name == "GPTNeoXForCausalLM (GPT-NeoX)":
        from transformers import GPTNeoXForCausalLM
        return GPTNeoXForCausalLM
    elif model_type_name == "GPTJForCausalLM (GPT-J)":
        from transformers import GPTJForCausalLM
        return GPTJForCausalLM
    else:
        # Fallback - should not happen if all dropdown options are handled above
        return AutoModelForCausalLM

def determine_model_class(model_id: str, token: str, manual_model_type: str = None):
    """
    Determines the appropriate model class based on either:
    1. Automatic detection from the model config, or
    2. User selection (if provided)
    """
    # If the user specified a manual model type (other than auto-detect), use that
    if manual_model_type and manual_model_type != "Auto-detect (recommended)":
        return get_model_class_by_name(manual_model_type)
    # Otherwise, try automatic detection
    try:
        # Load the model configuration to determine the appropriate class
        config = AutoConfig.from_pretrained(model_id, token=token, trust_remote_code=True)
        # Check whether the model type is present in the configuration
        if hasattr(config, 'model_type'):
            model_type = config.model_type.lower()
            # Handle different model types based on their config
            if model_type in ['qwen2_5_vl', 'qwen2-vl', 'qwen2vl']:
                from transformers import Qwen2_5_VLForConditionalGeneration
                return Qwen2_5_VLForConditionalGeneration
            elif model_type in ['qwen2', 'qwen', 'qwen2.5']:
                from transformers import Qwen2ForCausalLM
                return Qwen2ForCausalLM
            elif model_type in ['llama', 'llama2', 'llama3', 'llama3.1', 'llama3.2', 'llama3.3']:
                from transformers import LlamaForCausalLM
                return LlamaForCausalLM
            elif model_type in ['mistral', 'mixtral']:
                from transformers import MistralForCausalLM
                return MistralForCausalLM
            elif model_type in ['gemma', 'gemma2']:
                from transformers import GemmaForCausalLM, Gemma2ForCausalLM
                return Gemma2ForCausalLM if 'gemma2' in model_type else GemmaForCausalLM
            elif model_type in ['phi', 'phi2', 'phi3', 'phi3.5']:
                from transformers import PhiForCausalLM, Phi3ForCausalLM
                return Phi3ForCausalLM if 'phi3' in model_type else PhiForCausalLM
            elif model_type in ['falcon']:
                from transformers import FalconForCausalLM
                return FalconForCausalLM
            elif model_type in ['mpt']:
                from transformers import MptForCausalLM
                return MptForCausalLM
            elif model_type in ['gpt2', 'gpt', 'gpt_neox', 'gptj']:
                from transformers import GPT2LMHeadModel, GPTNeoXForCausalLM, GPTJForCausalLM
                if 'neox' in model_type:
                    return GPTNeoXForCausalLM
                elif 'j' in model_type:
                    return GPTJForCausalLM
                else:
                    return GPT2LMHeadModel
            else:
                # Default to AutoModelForCausalLM for standard text-generation models
                return AutoModelForCausalLM
        else:
            # If no model type is specified in the config, default to AutoModelForCausalLM
            return AutoModelForCausalLM
    except Exception as e:
        print(f"Could not determine model class from config: {e}")
        return AutoModelForCausalLM  # fall back to the default
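
# A minimal, illustrative sketch (not called by the app) of the two resolution paths.
# The repo id and token below are placeholders, not real values used by this Space.
def _model_class_example(token: str):
    """Illustrative only: auto-detection vs. explicit override."""
    auto_cls = determine_model_class("some-org/some-llama-model", token, "Auto-detect (recommended)")
    manual_cls = determine_model_class("some-org/some-llama-model", token, "LlamaForCausalLM (Llama, Llama2, Llama3)")
    print(auto_cls.__name__, manual_cls.__name__)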

def compress_and_upload(
    model_id: str,
    quant_method: str,
    model_type_selection: str,  # Manual model type selection from the dropdown
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(),  # Gradio progress tracker
):
    """
    Compresses a model using llm-compressor and uploads it to a new HF repo.
    """
    if not model_id:
        raise gr.Error("Please select a model from the search bar.")
    if oauth_token is None:
        raise gr.Error("Authentication error. Please log in to continue.")
    token = oauth_token.token
    try:
        # Use the provided token for all Hub interactions
        username = whoami(token=token)["name"]
        # --- 1. Load Model and Tokenizer ---
        progress(0, desc="Stage 1/5: Loading model and tokenizer...")
        # Determine the appropriate model class from the model's configuration or the user's selection
        model_class = determine_model_class(model_id, token, model_type_selection)
        try:
            # Show sub-steps during model loading
            progress(0.05, desc="Stage 1/5: Determining model class...")
            # Choose the device configuration based on available resources
            if torch.cuda.is_available():
                # With CUDA available, use auto device mapping to distribute the model across devices
                model = model_class.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    token=token,
                    trust_remote_code=True,
                )
            else:
                # Without CUDA, load on CPU
                model = model_class.from_pretrained(
                    model_id,
                    torch_dtype="auto",
                    device_map="cpu",
                    token=token,
                    trust_remote_code=True,
                )
            progress(0.15, desc="Stage 1/5: Model loaded.")
        except ValueError as e:
            if "Unrecognized configuration class" in str(e):
                # If automatic detection fails, fall back to AutoModel and let transformers handle it
                print(f"Automatic model class detection failed, falling back to AutoModel: {e}")
                progress(0.05, desc="Stage 1/5: Using fallback model class...")
                if torch.cuda.is_available():
                    model = AutoModel.from_pretrained(
                        model_id,
                        torch_dtype=torch.float16,
                        device_map="auto",
                        token=token,
                        trust_remote_code=True,
                    )
                else:
                    model = AutoModel.from_pretrained(
                        model_id,
                        torch_dtype="auto",
                        device_map="cpu",
                        token=token,
                        trust_remote_code=True,
                    )
                progress(0.15, desc="Stage 1/5: Model loaded with fallback class...")
            elif "offload_dir" in str(e):
                # If the error mentions offload_dir, retry with disk offloading.
                # Use a persistent temporary directory: offloaded weights must stay on disk
                # for as long as the model is in use, so don't delete it right after loading.
                print(f"Model requires offloading, retrying with a temporary offload directory: {e}")
                progress(0.05, desc="Stage 1/5: Setting up model with offloading...")
                import tempfile
                temp_dir = tempfile.mkdtemp(prefix="offload_")
                model = model_class.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else "auto",
                    device_map="auto",
                    offload_folder=temp_dir,
                    token=token,
                    trust_remote_code=True,
                )
                progress(0.15, desc="Stage 1/5: Model loaded with offloading...")
            else:
                raise
        except RuntimeError as e:
            if "out of memory" in str(e).lower() or "offload_dir" in str(e):
                # On out-of-memory or offload_dir errors, retry with memory-efficient loading
                print(f"Memory issue detected, retrying with CPU offloading: {e}")
                progress(0.05, desc="Stage 1/5: Setting up memory-efficient model loading...")
                # Use CPU/disk offloading to handle memory constraints; keep the offload
                # directory around for the lifetime of the model.
                import tempfile
                temp_dir = tempfile.mkdtemp(prefix="offload_")
                model = model_class.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else "auto",
                    device_map="auto",
                    offload_folder=temp_dir,
                    max_memory={0: "24GB", "cpu": "48GB"},  # Limit GPU memory usage
                    token=token,
                    trust_remote_code=True,
                )
                progress(0.15, desc="Stage 1/5: Model loaded with memory-efficient approach...")
            else:
                raise
        output_dir = f"{model_id.split('/')[-1]}-{quant_method}"

        # --- 2. Get Recipe ---
        progress(0.2, desc="Stage 2/5: Preparing quantization recipe...")
        if not model.config.architectures:
            raise gr.Error("Could not determine model architecture.")
        progress(0.25, desc="Stage 2/5: Analyzing model architecture...")
        recipe = get_quantization_recipe(quant_method, model.config.architectures[0])
        progress(0.3, desc="Stage 2/5: Quantization recipe prepared!")
        # --- 3. Run Compression ---
        progress(0.35, desc="Stage 3/5: Setting up quantization dataset...")
        # Use a multimodal dataset and data collator for Qwen2.5-VL models
        if model.config.architectures and "Qwen2_5_VLForConditionalGeneration" in model.config.architectures[0]:
            try:
                from datasets import load_dataset
                progress(0.36, desc="Stage 3/5: Loading multimodal dataset for Qwen2.5-VL model...")
                # Use a small subset of flickr30k for calibration if available
                ds = load_dataset("lmms-lab/flickr30k", split="test[:64]")
                ds = ds.shuffle(seed=42)
                progress(0.38, desc="Stage 3/5: Dataset loaded, preparing data collator...")

                # Data collator for multimodal inputs (batch size is fixed to 1)
                def qwen2_5_vl_data_collator(batch):
                    assert len(batch) == 1
                    return {key: torch.tensor(value) if isinstance(value, (list, int, float)) else value
                            for key, value in batch[0].items()}

                progress(0.4, desc="Stage 3/5: Starting quantization process for Qwen2.5-VL model...")
                oneshot(
                    model=model,
                    dataset=ds,
                    recipe=recipe,
                    save_compressed=True,
                    output_dir=output_dir,
                    max_seq_length=2048,  # Larger context for multimodal models
                    num_calibration_samples=64,
                    data_collator=qwen2_5_vl_data_collator,
                )
                progress(0.7, desc="Stage 3/5: Qwen2.5-VL quantization completed!")
            except Exception as e:
                print(f"Could not load multimodal dataset, falling back to text-only: {e}")
                progress(0.36, desc="Stage 3/5: Multimodal dataset failed, using fallback dataset...")
                # Fall back to a text-only calibration dataset
                from datasets import load_dataset
                fallback_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
                progress(0.4, desc="Stage 3/5: Running quantization with fallback dataset...")
                oneshot(
                    model=model,
                    dataset=fallback_ds,
                    recipe=recipe,
                    save_compressed=True,
                    output_dir=output_dir,
                    max_seq_length=512,
                    num_calibration_samples=64,
                )
                progress(0.7, desc="Stage 3/5: Quantization with fallback dataset completed!")
        else:
            # For non-multimodal models, use a text calibration dataset
            from datasets import load_dataset
            progress(0.36, desc="Stage 3/5: Loading text dataset...")
            ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
            progress(0.4, desc="Stage 3/5: Starting quantization process for standard model...")
            oneshot(
                model=model,
                dataset=ds,
                recipe=recipe,
                save_compressed=True,
                output_dir=output_dir,
                max_seq_length=512,
                num_calibration_samples=64,
            )
            progress(0.7, desc="Stage 3/5: Quantization completed!")
        # --- 4. Create Repo and Upload ---
        progress(0.75, desc="Stage 4/5: Preparing Hugging Face repository...")
        api = HfApi(token=token)
        repo_id = f"{username}/{output_dir}"
        progress(0.78, desc="Stage 4/5: Creating repository...")
        repo_url = api.create_repo(repo_id=repo_id, exist_ok=True)
        progress(0.8, desc="Stage 4/5: Uploading model files...")
        api.upload_folder(
            folder_path=output_dir,
            repo_id=repo_id,
            commit_message=f"Upload {quant_method} compressed model",
        )
        progress(0.9, desc="Stage 4/5: Upload completed!")

        # --- 5. Create Model Card ---
        progress(0.95, desc="Stage 5/5: Generating model card...")
        # The YAML front matter must start on the first line of the card content
        card_content = f"""---
license: apache-2.0
base_model: {model_id}
tags:
- llm-compressor
- quantization
- {quant_method.lower()}
---

# {quant_method} Compressed Model: {repo_id}

This model was compressed from [`{model_id}`](https://huggingface.co/{model_id}) using the [vLLM LLM-Compressor](https://github.com/vllm-project/llm-compressor) library.

This conversion was performed by the `llm-compressor-my-repo` Hugging Face Space.

## Quantization Method: {quant_method}

For more details on the recipe used, refer to the `recipe.yaml` file in this repository.
"""
        card = ModelCard(card_content)
        card.push_to_hub(repo_id, token=token)
        progress(1.0, desc="✅ All stages completed! Your compressed model is ready.")
        return f'<h1>✅ Success!</h1><br/>Model compressed and saved to your new repo: <a href="{repo_url}" target="_blank" style="text-decoration:underline">{repo_id}</a>'
    except gr.Error:
        raise
    except Exception as e:
        error_message = str(e).replace("\n", "<br/>")
        return f'<h1>❌ ERROR</h1><br/><pre style="white-space:pre-wrap;">{error_message}</pre>'
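
# A minimal, illustrative sketch (not called by the app) of consuming the resulting repo
# with vLLM. The repo id is a placeholder, and the vllm package is not a dependency of
# this Space; install it separately before running anything like this.
def _load_compressed_example():
    """Illustrative only: serve a compressed checkpoint with vLLM."""
    from vllm import LLM, SamplingParams
    llm = LLM(model="your-username/Model-Name-W4A16")  # placeholder repo id
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)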

# --- Gradio Interface ---
def build_gradio_app():
    with gr.Blocks(css="footer {display: none !important;}") as demo:
        gr.Markdown("# LLM-Compressor My Repo")
        gr.Markdown(
            "Log in, choose a model, select a quantization method, and this Space will create a new compressed model repository on your Hugging Face profile."
        )
        with gr.Row():
            login_button = gr.LoginButton(min_width=250)  # noqa: F841
        gr.Markdown("### 1. Select a Model from the Hugging Face Hub")
        model_input = HuggingfaceHubSearch(
            label="Search for a Model",
            search_type="model",
        )
        gr.Markdown("### 2. Choose a Quantization Method")
        quant_method_dropdown = gr.Dropdown(
            ["W4A16", "W8A16", "W8A8_INT8", "W8A8_FP8", "GPTQ", "FP8", "AWQ", "SmoothQuant", "SparseGPT"],
            label="Quantization Method",
            value="W4A16",
        )
        gr.Markdown("### 3. Model Type (auto-detected, but you can override it if needed)")
        model_type_dropdown = gr.Dropdown(
            choices=[
                "Auto-detect (recommended)",
                "CausalLM (standard text generation)",
                "Qwen2_5_VLForConditionalGeneration (Qwen2.5-VL)",
                "Qwen2ForCausalLM (Qwen2)",
                "LlamaForCausalLM (Llama, Llama2, Llama3)",
                "MistralForCausalLM (Mistral, Mixtral)",
                "GemmaForCausalLM (Gemma)",
                "Gemma2ForCausalLM (Gemma2)",
                "PhiForCausalLM (Phi, Phi2)",
                "Phi3ForCausalLM (Phi3)",
                "FalconForCausalLM (Falcon)",
                "MptForCausalLM (MPT)",
                "GPT2LMHeadModel (GPT2)",
                "GPTNeoXForCausalLM (GPT-NeoX)",
                "GPTJForCausalLM (GPT-J)",
            ],
            label="Model Type",
            value="Auto-detect (recommended)",
        )
        compress_button = gr.Button("Compress and Create Repo", variant="primary")
        output_html = gr.HTML(label="Result")
        # Disable the button while the job runs, then re-enable it when the job finishes
        compress_button.click(
            fn=lambda: gr.Button(interactive=False, value="Processing..."),
            inputs=[],
            outputs=[compress_button],
            queue=False,
        ).then(
            fn=compress_and_upload,
            inputs=[model_input, quant_method_dropdown, model_type_dropdown],
            outputs=output_html,
            show_progress=True,  # Show the built-in progress bar
        ).then(
            fn=lambda: gr.Button(interactive=True, value="Compress and Create Repo"),
            inputs=[],
            outputs=[compress_button],
            queue=False,
        )
    return demo


def main():
    demo = build_gradio_app()
    demo.queue(max_size=5).launch()


if __name__ == "__main__":
    main()