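"""Gradio app for converting Hugging Face models, with optional
quantization and ONNX export, then pushing the results to the Hub."""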
import json
import logging
import os
import sys
from typing import Any, Dict, Tuple

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))

import gradio as gr

from src.handlers import get_model_handler, TASK_CONFIGS
from src.optimizations.onnx_conversion import convert_to_onnx
from src.optimizations.quantize import quantize_onnx_model
from src.utilities.push_to_hub import push_to_hub
from src.utilities.resources import ResourceManager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def process_model(
    model_name: str,
    task: str,
    quantization_type: str,
    enable_onnx: bool,
    onnx_quantization: str,
    hf_token: str,
    repo_name: str,
    test_text: str,
) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
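    """Run the conversion pipeline and report progress to the UI.

    Optionally quantizes the model, optionally exports it to ONNX, and
    pushes the results to the Hugging Face Hub. Returns a (status dict,
    progress log, metrics dict) triple matching the three Gradio outputs.
    """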
    # Initialize before the try block so the outer except handler can
    # always return them, even if setup itself fails.
    status_updates = []
    metrics = {}
    try:
        resource_manager = ResourceManager()
        status = {
            "status": "Processing",
            "progress": 0,
            "current_step": "Initializing",
        }
        if not model_name or not hf_token or not repo_name:
            return (
                {"status": "Error", "progress": 0, "current_step": "Validation Failed"},
                "Model name, HuggingFace token, and repository name are required.",
                metrics,
            )
        status["progress"] = 0.2
        status["current_step"] = "Initialization"
        status_updates.append("Initialization complete")
        quantized_model_path = None
        if quantization_type != "None":
            status.update({"progress": 0.4, "current_step": "Quantization"})
            status_updates.append(f"Applying {quantization_type} quantization")
            if not test_text:
                test_text = TASK_CONFIGS[task]["example_text"]
            try:
                handler = get_model_handler(task, model_name, quantization_type, test_text)
                quantized_model = handler.compare()
                metrics = handler.get_metrics()
                # Round-trip through JSON so the metrics are guaranteed to be
                # serializable for the gr.JSON output component.
                metrics = json.loads(json.dumps(metrics))
                quantized_model_path = str(resource_manager.temp_dirs["quantized"] / "model")
                quantized_model.save_pretrained(quantized_model_path)
                status_updates.append("Quantization completed successfully")
            except Exception as e:
                logger.error(f"Quantization error: {e}", exc_info=True)
                return (
                    {"status": "Error", "progress": 0.4, "current_step": "Quantization Failed"},
                    f"Quantization failed: {e}",
                    metrics,
                )
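        # Step 2 (optional): export to ONNX, optionally quantize the
        # exported graph, and push the result to the Hub.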
        if enable_onnx:
            status.update({"progress": 0.6, "current_step": "ONNX Conversion"})
            status_updates.append("Converting to ONNX format")
            try:
                output_dir = str(resource_manager.temp_dirs["onnx"])
                onnx_result = convert_to_onnx(model_name, task, output_dir)
                if onnx_result is None:
                    return (
                        {"status": "Error", "progress": 0.6, "current_step": "ONNX Conversion Failed"},
                        "ONNX conversion failed.",
                        metrics,
                    )
                if onnx_quantization != "None":
                    status_updates.append(f"Applying {onnx_quantization} ONNX quantization")
                    quantize_onnx_model(output_dir, onnx_quantization)
                status.update({"progress": 0.8, "current_step": "Pushing ONNX Model"})
                status_updates.append("Pushing ONNX model to Hub")
                _, push_message = push_to_hub(
                    local_path=output_dir,
                    repo_name=f"{repo_name}-optimized",
                    hf_token=hf_token,
                    tags=["onnx", "optimum", task],
                )
                status_updates.append(push_message)
            except Exception as e:
                logger.error(f"ONNX error: {e}", exc_info=True)
                return (
                    {"status": "Error", "progress": 0.6, "current_step": "ONNX Processing Failed"},
                    f"ONNX processing failed: {e}",
                    metrics,
                )
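        # Step 3: push the quantized PyTorch model (if one was produced),
        # then clean up and report completion.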
if quantization_type != "None" and quantized_model_path: | |
status.update({"progress": 0.9, "current_step": "Pushing Quantized Model"}) | |
status_updates.append("Pushing quantized model to Hub") | |
result, push_message = push_to_hub( | |
local_path=quantized_model_path, | |
repo_name=f"{repo_name}-optimized", | |
hf_token=hf_token, | |
tags=["quantized", task, quantization_type], | |
) | |
status_updates.append(push_message) | |
status.update({"progress": 1.0, "status": "Complete", "current_step": "Completed"}) | |
cleanup_message = resource_manager.cleanup_temp_files() | |
status_updates.append(cleanup_message) | |
return ( | |
status, | |
"\n".join(status_updates), | |
metrics | |
) | |
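    # Catch-all: surface any unexpected failure as a structured status
    # payload rather than an unhandled traceback.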
    except Exception as e:
        logger.error(f"Error during processing: {e}", exc_info=True)
        return (
            {"status": "Error", "progress": 0, "current_step": "Process Failed"},
            f"An error occurred: {e}",
            metrics,
        )
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🤗 Model Conversion Hub
    Convert and optimize your Hugging Face models with quantization and ONNX support.
    """)
    with gr.Row():
        with gr.Column(scale=2):
            model_name = gr.Textbox(label="Model Name", placeholder="e.g., bert-base-uncased")
            task = gr.Dropdown(choices=list(TASK_CONFIGS.keys()), label="Task", value="text_classification")
            with gr.Group():
                gr.Markdown("### Quantization Settings")
                quantization_type = gr.Dropdown(choices=["None", "4-bit", "8-bit", "16-bit-float"], label="Quantization Type", value="None")
                test_text = gr.Textbox(label="Test Text", placeholder="Enter text for model evaluation", lines=3, visible=False)
            with gr.Group():
                gr.Markdown("### ONNX Settings")
                enable_onnx = gr.Checkbox(label="Enable ONNX Conversion")
                with gr.Group(visible=False) as onnx_group:
                    onnx_quantization = gr.Dropdown(choices=["None", "8-bit", "16-bit-int", "16-bit-float"], label="ONNX Quantization", value="None")
            with gr.Group():
                gr.Markdown("### HuggingFace Settings")
                hf_token = gr.Textbox(label="HuggingFace Token (Required)", type="password")
                repo_name = gr.Textbox(label="Repository Name")
        with gr.Column(scale=1):
            status_output = gr.JSON(label="Status", value={"status": "Ready", "progress": 0, "current_step": "Waiting"})
            message_output = gr.Markdown(label="Progress Messages")
            gr.Markdown("### Metrics")
            with gr.Group():
                metrics_output = gr.JSON(
                    value={
                        "model_sizes": {"original": 0.0, "quantized": 0.0},
                        "inference_times": {"original": 0.0, "quantized": 0.0},
                        "comparison_metrics": {},
                    },
                    show_label=True,
                )
            memory_info = gr.JSON(label="Resource Usage")
    convert_btn = gr.Button("🚀 Start Conversion", variant="primary")

    with gr.Accordion("ℹ️ Help", open=False):
        gr.Markdown("""
        ### Quick Guide
        1. Enter your model name and HuggingFace token.
        2. Select the appropriate task.
        3. Choose optimization options.
        4. Click Start Conversion.

        ### Tips
        - Ensure sufficient system resources.
        - Use test text to validate conversions.
        """)
    def update_memory_info():
        resource_manager = ResourceManager()
        return resource_manager.get_memory_info()

    # Event wiring: show the test-text box only when quantization is on,
    # refresh the example text per task, and reveal ONNX options on demand.
    quantization_type.change(lambda x: gr.update(visible=x != "None"), inputs=[quantization_type], outputs=[test_text])
    task.change(lambda x: gr.update(value=TASK_CONFIGS[x]["example_text"]), inputs=[task], outputs=[test_text])
    enable_onnx.change(lambda x: gr.update(visible=x), inputs=[enable_onnx], outputs=[onnx_group])

    convert_btn.click(
        process_model,
        inputs=[model_name, task, quantization_type, enable_onnx, onnx_quantization, hf_token, repo_name, test_text],
        outputs=[status_output, message_output, metrics_output],
    )
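    # Poll host resource usage every 30 seconds while a client is connected.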
    app.load(update_memory_info, outputs=[memory_info], every=30)
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True)