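"""Gradio app that AWQ-quantizes a Hugging Face causal LM and pushes the
result to the Hub (requires gradio, autoawq, transformers, huggingface_hub)."""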
import gradio as gr
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import HfApi, login


def quantize_model(
    model_id: str,
    hf_token: str,
    repo_name: str,
    progress=gr.Progress(track_tqdm=True),
):
    try:
        # Authenticate up front so gated and private models resolve
        login(token=hf_token, add_to_git_credential=True)
        api = HfApi(token=hf_token)

        # Fail fast with an actionable message if the model is unreachable
        try:
            api.model_info(model_id)
        except Exception as e:
            raise ValueError(
                f"Model access error: {e}. Check:\n"
                "1. Token permissions\n"
                "2. Model existence\n"
                f"3. Accept the model terms at https://huggingface.co/{model_id}"
            )

        config = AutoConfig.from_pretrained(
            model_id,
            token=hf_token,
            trust_remote_code=True,
        )

        # Collapse extended rope_scaling dicts to the legacy {"type", "factor"}
        # schema that older AWQ/Transformers code paths validate against
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            config.rope_scaling = {
                "type": config.rope_scaling.get("rope_type", "linear"),
                "factor": config.rope_scaling.get("factor", 1.0),
            }
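            # e.g. (values hypothetical) a Llama 3.1-style dict such as
            #   {"rope_type": "llama3", "factor": 8.0, "high_freq_factor": 4.0, ...}
            # collapses to {"type": "llama3", "factor": 8.0}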
        model = AutoAWQForCausalLM.from_pretrained(
            model_id,
            config=config,
            token=hf_token,
            trust_remote_code=True,
            device_map="auto",
        )

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            token=hf_token,
            trust_remote_code=True,
        )

        # 4-bit AWQ quantization: zero-point enabled, 128-element quant groups.
        # Kernel choice is a simple name-based heuristic: GEMM for Llama-family
        # ids, GEMV for everything else.
        model.quantize(
            tokenizer,
            quant_config={
                "zero_point": True,
                "q_group_size": 128,
                "w_bit": 4,
                "version": "GEMM" if "llama" in model_id.lower() else "GEMV",
            },
        )

        # Save and push the tokenizer alongside the quantized weights so the
        # destination repo is usable on its own
        save_path = f"{model_id.split('/')[-1]}-awq"
        model.save_quantized(save_path)
        tokenizer.save_pretrained(save_path)
        model.push_to_hub(repo_name, token=hf_token)
        tokenizer.push_to_hub(repo_name, token=hf_token)

        return f"✅ Success!\nSaved: {save_path}\nPushed to: {repo_name}"

    except Exception as e:
        return f"❌ Critical Error:\n{e}"
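

# Untested reload sketch (repo id illustrative): AutoAWQ exposes
# `from_quantized` for loading AWQ checkpoints back:
#
#   model = AutoAWQForCausalLM.from_quantized("username/repo-name")
#   tokenizer = AutoTokenizer.from_pretrained("username/repo-name")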

with gr.Blocks() as app:
    gr.Markdown("## 🔒 Secure AutoAWQ Quantizer")

    with gr.Row():
        model_id = gr.Textbox(
            label="Model ID",
            placeholder="meta-llama/Meta-Llama-3-8B-Instruct",
            info="Must have access rights",
        )
        hf_token = gr.Textbox(
            label="HF Token",
            type="password",
            info="Required for gated models",
        )
        repo_name = gr.Textbox(
            label="Destination Repo",
            info="Format: username/repo-name",
        )

    go_btn = gr.Button("Start Quantization", variant="primary")
    output = gr.Markdown()

    go_btn.click(
        quantize_model,
        inputs=[model_id, hf_token, repo_name],
        outputs=output,
    )
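
# Note: quantization runs for many minutes; Gradio's queue is the standard
# guard against request timeouts for long jobs, e.g. `app.queue().launch()`.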
app.launch()