"""Gradio app for converting and optimizing Hugging Face models.

Supports quantization, ONNX export with optional ONNX quantization, and
pushing the optimized artifacts to the Hugging Face Hub.
"""

import json
import logging
import os
import sys
from typing import Tuple, Dict, Any

# Make the bundled `src` directory importable when running this file directly.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))

import gradio as gr

from src.utilities.resources import ResourceManager
from src.utilities.push_to_hub import push_to_hub
from src.optimizations.onnx_conversion import convert_to_onnx
from src.optimizations.quantize import quantize_onnx_model
from src.handlers import get_model_handler, TASK_CONFIGS

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def process_model(
    model_name: str,
    task: str,
    quantization_type: str,
    enable_onnx: bool,
    onnx_quantization: str,
    hf_token: str,
    repo_name: str,
    test_text: str
) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
    """Run the selected optimizations on `model_name` and push the results to the Hub.

    Returns a (status, messages, metrics) tuple matching the three Gradio outputs.
    """
    # Initialized before the try block so the outer handler can always return it.
    metrics: Dict[str, Any] = {}
    try:
        resource_manager = ResourceManager()
        status_updates = []
        status = {
            "status": "Processing",
            "progress": 0,
            "current_step": "Initializing",
        }

        if not model_name or not hf_token or not repo_name:
            return (
                {"status": "Error", "progress": 0, "current_step": "Validation Failed"},
                "Model name, HuggingFace token, and repository name are required.",
                metrics
            )

        status["progress"] = 0.2
        status["current_step"] = "Initialization"
        status_updates.append("Initialization complete")

        # Optional quantization step.
        quantized_model_path = None
        if quantization_type != "None":
            status.update({"progress": 0.4, "current_step": "Quantization"})
            status_updates.append(f"Applying {quantization_type} quantization")
            if not test_text:
                test_text = TASK_CONFIGS[task]["example_text"]
            try:
                handler = get_model_handler(task, model_name, quantization_type, test_text)
                quantized_model = handler.compare()
                metrics = handler.get_metrics()
                # Round-trip through JSON so the metrics are serializable for gr.JSON.
                metrics = json.loads(json.dumps(metrics))
                quantized_model_path = str(resource_manager.temp_dirs["quantized"] / "model")
                quantized_model.save_pretrained(quantized_model_path)
                status_updates.append("Quantization completed successfully")
            except Exception as e:
                logger.error(f"Quantization error: {str(e)}", exc_info=True)
                return (
                    {"status": "Error", "progress": 0.4, "current_step": "Quantization Failed"},
                    f"Quantization failed: {str(e)}",
                    metrics
                )

        # Optional ONNX export, ONNX quantization, and upload.
        if enable_onnx:
            status.update({"progress": 0.6, "current_step": "ONNX Conversion"})
            status_updates.append("Converting to ONNX format")
            try:
                output_dir = str(resource_manager.temp_dirs["onnx"])
                onnx_result = convert_to_onnx(model_name, task, output_dir)
                if onnx_result is None:
                    return (
                        {"status": "Error", "progress": 0.6, "current_step": "ONNX Conversion Failed"},
                        "ONNX conversion failed.",
                        metrics
                    )
                if onnx_quantization != "None":
                    status_updates.append(f"Applying {onnx_quantization} ONNX quantization")
                    quantize_onnx_model(output_dir, onnx_quantization)
                status.update({"progress": 0.8, "current_step": "Pushing ONNX Model"})
                status_updates.append("Pushing ONNX model to Hub")
                result, push_message = push_to_hub(
                    local_path=output_dir,
                    repo_name=f"{repo_name}-optimized",
                    hf_token=hf_token,
                    tags=["onnx", "optimum", task],
                )
                status_updates.append(push_message)
            except Exception as e:
                logger.error(f"ONNX error: {str(e)}", exc_info=True)
                return (
                    {"status": "Error", "progress": 0.6, "current_step": "ONNX Processing Failed"},
                    f"ONNX processing failed: {str(e)}",
                    metrics
                )

        # Upload the quantized model, if one was produced.
        if quantization_type != "None" and quantized_model_path:
            status.update({"progress": 0.9, "current_step": "Pushing Quantized Model"})
            status_updates.append("Pushing quantized model to Hub")
            result, push_message = push_to_hub(
                local_path=quantized_model_path,
                repo_name=f"{repo_name}-optimized",
                hf_token=hf_token,
                tags=["quantized", task, quantization_type],
            )
            status_updates.append(push_message)

        status.update({"progress": 1.0, "status": "Complete", "current_step": "Completed"})
        cleanup_message = resource_manager.cleanup_temp_files()
        status_updates.append(cleanup_message)

        return (
            status,
            "\n".join(status_updates),
            metrics
        )
    except Exception as e:
        logger.error(f"Error during processing: {str(e)}", exc_info=True)
        return (
            {"status": "Error", "progress": 0, "current_step": "Process Failed"},
            f"An error occurred: {str(e)}",
            metrics
        )


# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🤗 Model Conversion Hub
    Convert and optimize your Hugging Face models with quantization and ONNX support.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            model_name = gr.Textbox(label="Model Name", placeholder="e.g., bert-base-uncased")
            task = gr.Dropdown(choices=list(TASK_CONFIGS.keys()), label="Task", value="text_classification")

            with gr.Group():
                gr.Markdown("### Quantization Settings")
                quantization_type = gr.Dropdown(
                    choices=["None", "4-bit", "8-bit", "16-bit-float"],
                    label="Quantization Type",
                    value="None"
                )
                test_text = gr.Textbox(
                    label="Test Text",
                    placeholder="Enter text for model evaluation",
                    lines=3,
                    visible=False
                )

            with gr.Group():
                gr.Markdown("### ONNX Settings")
                enable_onnx = gr.Checkbox(label="Enable ONNX Conversion")
                with gr.Group(visible=False) as onnx_group:
                    onnx_quantization = gr.Dropdown(
                        choices=["None", "8-bit", "16-bit-int", "16-bit-float"],
                        label="ONNX Quantization",
                        value="None"
                    )

            with gr.Group():
                gr.Markdown("### HuggingFace Settings")
                hf_token = gr.Textbox(label="HuggingFace Token (Required)", type="password")
                repo_name = gr.Textbox(label="Repository Name")

        with gr.Column(scale=1):
            status_output = gr.JSON(
                label="Status",
                value={"status": "Ready", "progress": 0, "current_step": "Waiting"}
            )
            message_output = gr.Markdown(label="Progress Messages")
            gr.Markdown("### Metrics")
            with gr.Group():
                metrics_output = gr.JSON(
                    value={
                        "model_sizes": {"original": 0.0, "quantized": 0.0},
                        "inference_times": {"original": 0.0, "quantized": 0.0},
                        "comparison_metrics": {}
                    },
                    show_label=True
                )
            memory_info = gr.JSON(label="Resource Usage")

    convert_btn = gr.Button("🚀 Start Conversion", variant="primary")

    with gr.Accordion("ℹ️ Help", open=False):
        gr.Markdown("""
        ### Quick Guide
        1. Enter your model name and HuggingFace token.
        2. Select the appropriate task.
        3. Choose optimization options.
        4. Click Start Conversion.

        ### Tips
        - Ensure sufficient system resources.
        - Use test text to validate conversions.
        """)

    def update_memory_info():
        resource_manager = ResourceManager()
        return resource_manager.get_memory_info()

    # Show the test-text box only when quantization is enabled.
    quantization_type.change(
        lambda x: gr.update(visible=x != "None"),
        inputs=[quantization_type],
        outputs=[test_text]
    )
    # Pre-fill the test text with the selected task's example.
    task.change(
        lambda x: gr.update(value=TASK_CONFIGS[x]["example_text"]),
        inputs=[task],
        outputs=[test_text]
    )
    # Reveal the ONNX options only when ONNX conversion is enabled.
    enable_onnx.change(lambda x: gr.update(visible=x), inputs=[enable_onnx], outputs=[onnx_group])

    convert_btn.click(
        process_model,
        inputs=[model_name, task, quantization_type, enable_onnx, onnx_quantization, hf_token, repo_name, test_text],
        outputs=[status_output, message_output, metrics_output]
    )

    # Refresh the resource-usage panel every 30 seconds.
    app.load(update_memory_info, outputs=[memory_info], every=30)

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True)
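
# Usage sketch (hypothetical filename and example values): run `python app.py`,
# open http://localhost:7860, enter a model such as "bert-base-uncased", pick a
# task, supply a Hub token with write access, and click "🚀 Start Conversion".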