"""
Converts ONNX models to TFLite for offline edge deployment (Android phones in rural areas).
Note: Whisper's encoder and decoder are exported as separate TFLite models and
orchestrated together at inference time.
Requires: onnx-tf, tensorflow (install separately — large dependencies)
"""
from __future__ import annotations

import logging
from pathlib import Path

logger = logging.getLogger(__name__)
class TFLiteConverter:
    """Converts ONNX Whisper models to TFLite format for edge deployment."""

    def convert(
        self,
        onnx_encoder_path: str,
        onnx_decoder_path: str,
        output_dir: str,
        quantize: bool = True,
    ) -> dict[str, Path]:
        """
        Convert encoder and decoder ONNX models to TFLite.
        Returns paths to the generated .tflite files.
        """
        target_dir = Path(output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        encoder_out = target_dir / "encoder.tflite"
        decoder_out = target_dir / "decoder.tflite"

        logger.info("Converting encoder ONNX → TFLite...")
        self._onnx_to_tflite(onnx_encoder_path, str(encoder_out), quantize=quantize)

        logger.info("Converting decoder ONNX → TFLite...")
        self._onnx_to_tflite(onnx_decoder_path, str(decoder_out), quantize=quantize)

        return {"encoder": encoder_out, "decoder": decoder_out}

    def _onnx_to_tflite(self, onnx_path: str, output_path: str, quantize: bool) -> None:
        """Convert a single ONNX model to TFLite via onnx-tf + tensorflow."""
        import tempfile

        # Heavy optional dependencies — imported lazily so the rest of the
        # package works without them; fail with an actionable message.
        try:
            import onnx
            import onnx_tf
            import tensorflow as tf
        except ImportError as e:
            raise ImportError(
                "TFLite conversion requires onnx-tf and tensorflow. "
                "Install with: pip install onnx-tf tensorflow"
            ) from e

        with tempfile.TemporaryDirectory() as saved_model_dir:
            # Step 1: ONNX graph → TensorFlow SavedModel (intermediate form
            # living only inside the temporary directory).
            tf_rep = onnx_tf.backend.prepare(onnx.load(onnx_path))
            tf_rep.export_graph(saved_model_dir)

            # Step 2: SavedModel → flat .tflite buffer, with optional
            # post-training quantization to shrink the model for phones.
            lite_converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
            if quantize:
                lite_converter.optimizations = [tf.lite.Optimize.DEFAULT]
            flatbuffer = lite_converter.convert()

        Path(output_path).write_bytes(flatbuffer)
        size_mb = Path(output_path).stat().st_size / 1e6
        logger.info("TFLite model saved: %s (%.1f MB)", output_path, size_mb)