import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))

import json
import logging
from typing import Tuple, Dict, Any

import gradio as gr

from src.utilities.resources import ResourceManager
from src.utilities.push_to_hub import push_to_hub
from src.optimizations.onnx_conversion import convert_to_onnx
from src.optimizations.quantize import quantize_onnx_model
from src.handlers import get_model_handler, TASK_CONFIGS

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def process_model(
    model_name: str,
    task: str,
    quantization_type: str,
    enable_onnx: bool,
    onnx_quantization: str,
    hf_token: str,
    repo_name: str,
    test_text: str,
) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
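    """Quantize and/or ONNX-convert a Hugging Face model, then push the results to the Hub.

    Returns a (status, progress messages, metrics) tuple matching the three
    Gradio outputs wired up below.
    """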
    # Bind these before the try block so the outer exception handler can
    # reference them even if setup fails early.
    metrics: Dict[str, Any] = {}
    status_updates = []
    try:
        resource_manager = ResourceManager()
        status = {
            "status": "Processing",
            "progress": 0,
            "current_step": "Initializing",
        }
        if not model_name or not hf_token or not repo_name:
            return (
                {"status": "Error", "progress": 0, "current_step": "Validation Failed"},
                "Model name, HuggingFace token, and repository name are required.",
                metrics,
            )
        status["progress"] = 0.2
        status["current_step"] = "Initialization"
        status_updates.append("Initialization complete")
        quantized_model_path = None
        if quantization_type != "None":
            status.update({"progress": 0.4, "current_step": "Quantization"})
            status_updates.append(f"Applying {quantization_type} quantization")
            if not test_text:
                test_text = TASK_CONFIGS[task]["example_text"]
            try:
                handler = get_model_handler(task, model_name, quantization_type, test_text)
                quantized_model = handler.compare()
                metrics = handler.get_metrics()
                # Round-trip through JSON so the metrics are guaranteed to be
                # serializable for the gr.JSON output component.
                metrics = json.loads(json.dumps(metrics))
                quantized_model_path = str(resource_manager.temp_dirs["quantized"] / "model")
                quantized_model.save_pretrained(quantized_model_path)
                status_updates.append("Quantization completed successfully")
            except Exception as e:
                logger.error(f"Quantization error: {e}", exc_info=True)
                return (
                    {"status": "Error", "progress": 0.4, "current_step": "Quantization Failed"},
                    f"Quantization failed: {e}",
                    metrics,
                )
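
        # Step 2: optional ONNX export, with optional ONNX-level quantization,
        # followed by a push to the Hub.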
        if enable_onnx:
            status.update({"progress": 0.6, "current_step": "ONNX Conversion"})
            status_updates.append("Converting to ONNX format")
            try:
                output_dir = str(resource_manager.temp_dirs["onnx"])
                onnx_result = convert_to_onnx(model_name, task, output_dir)
                if onnx_result is None:
                    return (
                        {"status": "Error", "progress": 0.6, "current_step": "ONNX Conversion Failed"},
                        "ONNX conversion failed.",
                        metrics,
                    )
                if onnx_quantization != "None":
                    status_updates.append(f"Applying {onnx_quantization} ONNX quantization")
                    quantize_onnx_model(output_dir, onnx_quantization)
                status.update({"progress": 0.8, "current_step": "Pushing ONNX Model"})
                status_updates.append("Pushing ONNX model to Hub")
                _, push_message = push_to_hub(
                    local_path=output_dir,
                    repo_name=f"{repo_name}-optimized",
                    hf_token=hf_token,
                    tags=["onnx", "optimum", task],
                )
                status_updates.append(push_message)
            except Exception as e:
                logger.error(f"ONNX error: {e}", exc_info=True)
                return (
                    {"status": "Error", "progress": 0.6, "current_step": "ONNX Processing Failed"},
                    f"ONNX processing failed: {e}",
                    metrics,
                )
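
        # Step 3: push the quantized PyTorch model. Note that this targets the
        # same "<repo_name>-optimized" repository as the ONNX push above.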
        if quantization_type != "None" and quantized_model_path:
            status.update({"progress": 0.9, "current_step": "Pushing Quantized Model"})
            status_updates.append("Pushing quantized model to Hub")
            _, push_message = push_to_hub(
                local_path=quantized_model_path,
                repo_name=f"{repo_name}-optimized",
                hf_token=hf_token,
                tags=["quantized", task, quantization_type],
            )
            status_updates.append(push_message)
        status.update({"progress": 1.0, "status": "Complete", "current_step": "Completed"})
        cleanup_message = resource_manager.cleanup_temp_files()
        status_updates.append(cleanup_message)
        return (
            status,
            "\n".join(status_updates),
            metrics,
        )
    except Exception as e:
        logger.error(f"Error during processing: {e}", exc_info=True)
        return (
            {"status": "Error", "progress": 0, "current_step": "Process Failed"},
            f"An error occurred: {e}",
            metrics,
        )


# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🤗 Model Conversion Hub
    Convert and optimize your Hugging Face models with quantization and ONNX support.
    """)
    with gr.Row():
        with gr.Column(scale=2):
            model_name = gr.Textbox(label="Model Name", placeholder="e.g., bert-base-uncased")
            task = gr.Dropdown(choices=list(TASK_CONFIGS.keys()), label="Task", value="text_classification")
            with gr.Group():
                gr.Markdown("### Quantization Settings")
                quantization_type = gr.Dropdown(choices=["None", "4-bit", "8-bit", "16-bit-float"], label="Quantization Type", value="None")
                test_text = gr.Textbox(label="Test Text", placeholder="Enter text for model evaluation", lines=3, visible=False)
            with gr.Group():
                gr.Markdown("### ONNX Settings")
                enable_onnx = gr.Checkbox(label="Enable ONNX Conversion")
                with gr.Group(visible=False) as onnx_group:
                    onnx_quantization = gr.Dropdown(choices=["None", "8-bit", "16-bit-int", "16-bit-float"], label="ONNX Quantization", value="None")
            with gr.Group():
                gr.Markdown("### HuggingFace Settings")
                hf_token = gr.Textbox(label="HuggingFace Token (Required)", type="password")
                repo_name = gr.Textbox(label="Repository Name")
        with gr.Column(scale=1):
            status_output = gr.JSON(label="Status", value={"status": "Ready", "progress": 0, "current_step": "Waiting"})
            message_output = gr.Markdown(label="Progress Messages")
            gr.Markdown("### Metrics")
            with gr.Group():
                # Placeholder structure mirroring handler.get_metrics() output.
                metrics_output = gr.JSON(
                    value={
                        "model_sizes": {"original": 0.0, "quantized": 0.0},
                        "inference_times": {"original": 0.0, "quantized": 0.0},
                        "comparison_metrics": {},
                    },
                    show_label=True,
                )
                memory_info = gr.JSON(label="Resource Usage")

    convert_btn = gr.Button("🚀 Start Conversion", variant="primary")
    with gr.Accordion("ℹ️ Help", open=False):
        gr.Markdown("""
        ### Quick Guide
        1. Enter your model name and HuggingFace token.
        2. Select the appropriate task.
        3. Choose optimization options.
        4. Click Start Conversion.

        ### Tips
        - Ensure sufficient system resources.
        - Use test text to validate conversions.
        """)
    def update_memory_info():
        """Sample current resource usage for the periodic poll below."""
        resource_manager = ResourceManager()
        return resource_manager.get_memory_info()
    # Show the test-text box only when quantization is enabled, and prefill it
    # with the selected task's example text.
    quantization_type.change(lambda x: gr.update(visible=x != "None"), inputs=[quantization_type], outputs=[test_text])
    task.change(lambda x: gr.update(value=TASK_CONFIGS[x]["example_text"]), inputs=[task], outputs=[test_text])
    enable_onnx.change(lambda x: gr.update(visible=x), inputs=[enable_onnx], outputs=[onnx_group])
    convert_btn.click(
        process_model,
        inputs=[model_name, task, quantization_type, enable_onnx, onnx_quantization, hf_token, repo_name, test_text],
        outputs=[status_output, message_output, metrics_output],
    )
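    # Poll resource usage every 30 seconds while a session is open.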
    app.load(update_memory_info, outputs=[memory_info], every=30)

if __name__ == "__main__":
    # share=True has no effect when running on Hugging Face Spaces (Gradio
    # warns and disables it there), but enables a public tunnel locally.
    app.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True)