| """
|
| Model Download Script.
|
|
|
Downloads and caches the Wav2Vec2 model for the VoiceAuth API.
|
| """
|
|
|
| import argparse
|
| import os
|
| import sys
|
| from pathlib import Path
|
|
|
|
|
# Put the directory two levels above this file (presumably the project root)
# at the front of the module search path.
# NOTE(review): nothing in this script imports project modules directly;
# confirm whether this is still needed.
_project_root = Path(__file__).parent.parent
sys.path.insert(0, str(_project_root))
|
|
|
|
|
| def download_model(
|
| model_name: str = "facebook/wav2vec2-base",
|
| output_dir: str | None = None,
|
| force: bool = False,
|
| ) -> None:
|
| """
|
| Download and cache the Wav2Vec2 model.
|
|
|
| Args:
|
| model_name: HuggingFace model name or path
|
| output_dir: Optional local directory to save model
|
| force: Force re-download even if cached
|
| """
|
| print("\n" + "=" * 60)
|
| print("VoiceAuth - Model Download")
|
| print("=" * 60 + "\n")
|
| print(f"Model: {model_name}")
|
|
|
| if output_dir:
|
| print(f"Output: {output_dir}")
|
|
|
| print("\nDownloading model components...")
|
| print("-" * 40)
|
|
|
| try:
|
|
|
| from transformers import Wav2Vec2ForSequenceClassification
|
| from transformers import Wav2Vec2Processor
|
|
|
|
|
| print("\n[1/2] Downloading Wav2Vec2Processor...")
|
| processor = Wav2Vec2Processor.from_pretrained(
|
| model_name,
|
| force_download=force,
|
| )
|
| print(" [OK] Processor downloaded")
|
|
|
|
|
| print("\n[2/2] Downloading Wav2Vec2ForSequenceClassification...")
|
| model = Wav2Vec2ForSequenceClassification.from_pretrained(
|
| model_name,
|
| num_labels=2,
|
| label2id={"HUMAN": 0, "AI_GENERATED": 1},
|
| id2label={0: "HUMAN", 1: "AI_GENERATED"},
|
| force_download=force,
|
| )
|
| print(" [OK] Model downloaded")
|
|
|
|
|
| if output_dir:
|
| output_path = Path(output_dir)
|
| output_path.mkdir(parents=True, exist_ok=True)
|
|
|
| print(f"\nSaving to {output_path}...")
|
| processor.save_pretrained(output_path)
|
| model.save_pretrained(output_path)
|
| print("[OK] Model saved locally")
|
|
|
| print("\n" + "=" * 60)
|
| print("Download Complete!")
|
| print("=" * 60)
|
|
|
|
|
| cache_dir = os.environ.get(
|
| "HF_HOME",
|
| os.path.expanduser("~/.cache/huggingface"),
|
| )
|
| print(f"\nCache location: {cache_dir}")
|
|
|
| if output_dir:
|
| print(f"Local copy: {output_dir}")
|
|
|
| print("\nYou can now start the API with:")
|
| print(" uvicorn app.main:app --reload")
|
| print()
|
|
|
| except Exception as e:
|
| print(f"\n[ERROR] Error downloading model: {e}")
|
| sys.exit(1)
|
|
|
|
|
| def main() -> None:
|
| """Main entry point."""
|
| parser = argparse.ArgumentParser(
|
| description="Download Wav2Vec2 model for VoiceAuth API",
|
| formatter_class=argparse.RawDescriptionHelpFormatter,
|
| epilog="""
|
| Examples:
|
| python download_model.py
|
| python download_model.py --model facebook/wav2vec2-large-xlsr-53
|
| python download_model.py --output ./models
|
| python download_model.py --force
|
| """,
|
| )
|
|
|
| parser.add_argument(
|
| "--model",
|
| type=str,
|
| default="facebook/wav2vec2-base",
|
| help="HuggingFace model name (default: facebook/wav2vec2-base)",
|
| )
|
| parser.add_argument(
|
| "--output",
|
| type=str,
|
| default=None,
|
| help="Optional local directory to save model",
|
| )
|
| parser.add_argument(
|
| "--force",
|
| action="store_true",
|
| help="Force re-download even if cached",
|
| )
|
|
|
| args = parser.parse_args()
|
|
|
| download_model(
|
| model_name=args.model,
|
| output_dir=args.output,
|
| force=args.force,
|
| )
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|