# --- Hugging Face file-viewer residue (kept as comments so the file parses) ---
# bluenevus's picture
# Update app.py
# 1ab6fac verified
# raw / history / blame — 3.09 kB
import gradio as gr
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import HfApi, login
def quantize_model(
    model_id: str,
    hf_token: str,
    repo_name: str,
    progress=gr.Progress(track_tqdm=True)
):
    """Quantize a causal LM to 4-bit AWQ and push the result to the Hub.

    Parameters
    ----------
    model_id : str
        Source model repo id (e.g. "meta-llama/Meta-Llama-3-8B-Instruct").
    hf_token : str
        Hugging Face token with read access to ``model_id`` and write
        access to ``repo_name``.
    repo_name : str
        Destination repo id in "username/repo-name" form.
    progress : gr.Progress
        Gradio progress tracker; mirrors tqdm bars from the quantizer.

    Returns
    -------
    str
        Human-readable success or error message for the UI. All failures
        are caught and reported as text so the Gradio app never crashes.
    """
    try:
        # Validate credentials first so we fail fast with a clear message.
        login(token=hf_token, add_to_git_credential=True)
        api = HfApi(token=hf_token)
        # Check model accessibility before downloading anything large.
        try:
            api.model_info(model_id)
        except Exception as e:
            raise ValueError(
                f"Model access error: {str(e)}. Check:\n"
                f"1. Token permissions\n"
                f"2. Model existence\n"
                f"3. Accept model terms at https://huggingface.co/{model_id}"
            )
        # Load config with proper auth.
        config = AutoConfig.from_pretrained(
            model_id,
            token=hf_token,
            trust_remote_code=True
        )
        # Llama 3 ships an extended rope_scaling dict that some loaders
        # reject; normalize it down to the two keys they expect.
        if hasattr(config, 'rope_scaling') and isinstance(config.rope_scaling, dict):
            config.rope_scaling = {
                "type": config.rope_scaling.get("rope_type", "linear"),
                "factor": config.rope_scaling.get("factor", 1.0)
            }
        # Load model with validated credentials.
        model = AutoAWQForCausalLM.from_pretrained(
            model_id,
            config=config,
            token=hf_token,
            trust_remote_code=True,
            device_map="auto"
        )
        # Load tokenizer with the same credentials.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            token=hf_token,
            trust_remote_code=True
        )
        # Quantize with auto-detected settings.
        model.quantize(tokenizer, quant_config={
            "zero_point": True,
            "q_group_size": 128,
            "w_bit": 4,
            "version": "GEMM" if "llama" in model_id.lower() else "GEMV"
        })
        # Save locally. FIX: also save the tokenizer next to the weights —
        # previously only the model was saved/pushed, leaving the
        # destination repo unusable without a tokenizer.
        save_path = f"{model_id.split('/')[-1]}-awq"
        model.save_quantized(save_path)
        tokenizer.save_pretrained(save_path)
        # FIX: push the whole folder (weights + config + tokenizer) via
        # HfApi instead of relying on a push_to_hub method on the AWQ
        # wrapper; create the repo idempotently first.
        api.create_repo(repo_name, exist_ok=True)
        api.upload_folder(folder_path=save_path, repo_id=repo_name)
        return f"βœ… Success!\nSaved: {save_path}\nPushed to: {repo_name}"
    except Exception as e:
        # Top-level UI boundary: report, don't raise.
        return f"❌ Critical Error:\n{str(e)}"
# --- Gradio UI: collect model id / token / destination repo, run the
# --- quantizer on click, and render its returned message as Markdown.
with gr.Blocks() as app:
    gr.Markdown("## πŸ” Secure AutoAWQ Quantizer")
    # Credentials and source model sit side by side in one row.
    with gr.Row():
        model_box = gr.Textbox(
            label="Model ID",
            placeholder="meta-llama/Meta-Llama-3-8B-Instruct",
            info="Must have access rights",
        )
        token_box = gr.Textbox(
            label="HF Token",
            type="password",
            info="Required for gated models",
        )
    repo_box = gr.Textbox(
        label="Destination Repo",
        info="Format: username/repo-name",
    )
    start_button = gr.Button("Start Quantization", variant="primary")
    status_output = gr.Markdown()
    # Wire the button to the quantizer; its string result fills the Markdown.
    start_button.click(
        quantize_model,
        inputs=[model_box, token_box, repo_box],
        outputs=status_output,
    )
app.launch()