| | |
| | """ |
| | Model Download and Setup for FinEE v2.0 |
| | ======================================== |
| | |
| | Downloads and prepares base models for fine-tuning: |
| | - Llama 3.1 8B Instruct (Primary) |
| | - Qwen2.5 7B Instruct (Backup) |
| | |
| | Supports: |
| | - MLX format for Apple Silicon |
| | - PyTorch/Transformers format |
| | - GGUF for llama.cpp |
| | """ |
| |
|
| | import argparse |
| | import os |
| | import subprocess |
| | import sys |
| | from pathlib import Path |
| |
|
| |
|
# Registry of supported base models, keyed by the short name used on the CLI
# (these keys are also the argparse choices and the `list` table rows, in this
# order). Each entry holds Hugging Face repo ids for the formats this script
# can fetch -- "hf_name" always, "mlx_name" / "gguf_name" when available --
# plus display metadata. A missing format key means that format is not offered
# for the model (phi-3-medium lists no GGUF repo).
MODELS = {
    "llama-3.1-8b": dict(
        hf_name="meta-llama/Llama-3.1-8B-Instruct",
        mlx_name="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
        gguf_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
        description="Llama 3.1 8B Instruct - Best instruction-following",
        size="8B",
        context="128K",
    ),
    "qwen2.5-7b": dict(
        hf_name="Qwen/Qwen2.5-7B-Instruct",
        mlx_name="mlx-community/Qwen2.5-7B-Instruct-4bit",
        gguf_name="Qwen/Qwen2.5-7B-Instruct-GGUF",
        description="Qwen 2.5 7B - Excellent multilingual support",
        size="7B",
        context="128K",
    ),
    "mistral-7b": dict(
        hf_name="mistralai/Mistral-7B-Instruct-v0.3",
        mlx_name="mlx-community/Mistral-7B-Instruct-v0.3-4bit",
        gguf_name="bartowski/Mistral-7B-Instruct-v0.3-GGUF",
        description="Mistral 7B - Fast and efficient",
        size="7B",
        context="32K",
    ),
    "phi-3-medium": dict(
        hf_name="microsoft/Phi-3-medium-128k-instruct",
        mlx_name="mlx-community/Phi-3-medium-128k-instruct-4bit",
        description="Phi-3 Medium - Compact but powerful",
        size="14B",
        context="128K",
    ),
}
| |
|
| |
|
def download_mlx_model(model_key: str, output_dir: Path):
    """Download the pre-converted MLX build of a model from the Hugging Face Hub.

    Args:
        model_key: Key into ``MODELS``; its ``mlx_name`` repo is fetched.
        output_dir: Base directory; files land in ``output_dir/<model_key>/mlx``.

    Returns:
        True on success. False when the model has no MLX repo, when
        ``huggingface_hub`` is not installed, or when the download fails
        (the reason is printed, never raised).
    """
    model = MODELS[model_key]
    mlx_name = model.get("mlx_name")

    if not mlx_name:
        print(f"❌ No MLX version available for {model_key}")
        return False

    print(f"\n📥 Downloading {model_key} (MLX format)...")
    print(f" From: {mlx_name}")

    output_path = output_dir / model_key / "mlx"
    output_path.mkdir(parents=True, exist_ok=True)

    # Report a missing dependency distinctly rather than as a "download" error.
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("❌ huggingface_hub is not installed (pip install huggingface_hub)")
        return False

    try:
        snapshot_download(
            repo_id=mlx_name,
            local_dir=str(output_path),
            # NOTE(review): newer huggingface_hub releases deprecate this
            # argument -- confirm against the installed version.
            local_dir_use_symlinks=False,
        )
    except Exception as e:  # network, auth, missing repo, disk errors
        print(f"❌ Download failed: {e}")
        return False

    print(f"✅ Downloaded to: {output_path}")
    return True
| |
|
| |
|
def download_hf_model(model_key: str, output_dir: Path, *, ignore_patterns=None):
    """Download a model in HuggingFace (transformers) format.

    Args:
        model_key: Key into ``MODELS``; its ``hf_name`` repo is fetched.
        output_dir: Base directory; files land in ``output_dir/<model_key>/hf``.
        ignore_patterns: Glob patterns of repo files to skip. ``None`` uses the
            default list below: the legacy ``*.bin`` / ``*.h5`` weights skipped
            before, plus other duplicate checkpoint formats that are not
            needed when safetensors shards are present.

    Returns:
        True on success, False on failure (the error is printed, never raised).
    """
    model = MODELS[model_key]
    hf_name = model["hf_name"]

    if ignore_patterns is None:
        ignore_patterns = [
            "*.bin",
            "*.h5",
            "*.msgpack",
            "*.pth",
            # Vendor-native checkpoint copies that some repos ship next to the
            # transformers shards; they can double the download size.
            # NOTE(review): confirm no supported repo ships *only* these.
            "original/*",
            "consolidated*.safetensors",
        ]

    print(f"\n📥 Downloading {model_key} (HuggingFace format)...")
    print(f" From: {hf_name}")

    output_path = output_dir / model_key / "hf"
    output_path.mkdir(parents=True, exist_ok=True)

    # Report a missing dependency distinctly rather than as a "download" error.
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("❌ huggingface_hub is not installed (pip install huggingface_hub)")
        return False

    try:
        snapshot_download(
            repo_id=hf_name,
            local_dir=str(output_path),
            # NOTE(review): newer huggingface_hub releases deprecate this
            # argument -- confirm against the installed version.
            local_dir_use_symlinks=False,
            ignore_patterns=list(ignore_patterns),
        )
    except Exception as e:  # network, auth (gated repos), missing repo, disk
        print(f"❌ Download failed: {e}")
        print(" Note: Some models require HuggingFace login")
        print(" Run: huggingface-cli login")
        return False

    print(f"✅ Downloaded to: {output_path}")
    return True
| |
|
| |
|
def download_gguf_model(model_key: str, output_dir: Path, quant: str = "Q4_K_M"):
    """Download the GGUF file(s) for one quantization level of a model.

    Args:
        model_key: Key into ``MODELS``; its ``gguf_name`` repo is searched.
        output_dir: Base directory; files land in ``output_dir/<model_key>/gguf``.
        quant: Quantization tag to match in GGUF file names (e.g. ``"Q4_K_M"``);
            matched case-insensitively.

    Returns:
        True if at least one matching ``.gguf`` file is present after the
        download, False otherwise (the reason is printed, never raised).
    """
    model = MODELS[model_key]
    gguf_name = model.get("gguf_name")

    if not gguf_name:
        print(f"❌ No GGUF version available for {model_key}")
        return False

    print(f"\n📥 Downloading {model_key} (GGUF {quant} format)...")
    print(f" From: {gguf_name}")

    output_path = output_dir / model_key / "gguf"
    output_path.mkdir(parents=True, exist_ok=True)

    # hf_hub_download() takes an exact filename, so a wildcard such as
    # "*Q4_K_M*.gguf" can never match a file. snapshot_download() supports
    # glob filtering via allow_patterns. Repos may spell the tag in upper or
    # lower case, so both spellings are tried; the pattern also matches
    # multi-part (sharded) GGUF files.
    patterns = sorted({f"*{q}*.gguf" for q in (quant, quant.upper(), quant.lower())})

    try:
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=gguf_name,
            local_dir=str(output_path),
            # NOTE(review): newer huggingface_hub releases deprecate this
            # argument -- confirm against the installed version.
            local_dir_use_symlinks=False,
            allow_patterns=patterns,
        )
    except Exception as e:  # missing package, network, auth, missing repo
        print(f"❌ Download failed: {e}")
        return False

    # A pattern that matches nothing does not raise, so confirm that a
    # file for the requested quantization is actually on disk.
    wanted = quant.lower()
    matches = [p for p in output_path.rglob("*.gguf") if wanted in p.name.lower()]
    if not matches:
        print(f"❌ Download failed: no '{quant}' GGUF file found in {gguf_name}")
        return False

    print(f"✅ Downloaded to: {output_path}")
    return True
| |
|
| |
|
def convert_to_mlx(model_path: Path, output_path: Path, quantize: bool = True, *, q_bits: int = 4):
    """Convert a HuggingFace-format model to MLX by running ``mlx_lm.convert``.

    The converter runs in a subprocess under the current interpreter
    (``sys.executable``).

    Args:
        model_path: Source model, passed as ``--hf-path``.
        output_path: Destination directory, passed as ``--mlx-path``.
        quantize: When True, add ``--quantize`` to the converter call.
        q_bits: Bit width passed as ``--q-bits`` when quantizing (default 4,
            the value that was previously hard-coded).

    Returns:
        True if the converter exited with status 0. False if ``mlx_lm`` is
        not importable or the converter failed.
    """
    import importlib.util

    print("\n🔄 Converting to MLX format...")

    # The subprocess uses this same interpreter, so a missing package can be
    # reported clearly here instead of as an opaque non-zero exit status.
    if importlib.util.find_spec("mlx_lm") is None:
        print("❌ Conversion failed: mlx_lm is not installed (pip install mlx-lm)")
        return False

    cmd = [
        sys.executable, "-m", "mlx_lm.convert",
        "--hf-path", str(model_path),
        "--mlx-path", str(output_path),
    ]
    if quantize:
        cmd += ["--quantize", "--q-bits", str(q_bits)]

    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Conversion failed: {e}")
        return False

    print(f"✅ Converted to: {output_path}")
    return True
| |
|
| |
|
def verify_model(model_path: Path, backend: str = "mlx"):
    """Check that a downloaded model can be loaded (and, for MLX, run).

    Args:
        model_path: Location of the model files.
        backend: ``"mlx"`` loads the model with ``mlx_lm`` and generates up to
            10 tokens from ``"Hello"``. ``"transformers"`` loads the tokenizer
            and the causal-LM model with ``transformers``.

    Returns:
        True if loading (and, for MLX, the smoke-test generation) succeeded.
        False on any error or for an unsupported ``backend``. Previously an
        unsupported backend returned ``None`` silently.
    """
    print(f"\n🔍 Verifying model at {model_path}...")

    if backend == "mlx":
        try:
            from mlx_lm import generate, load

            model, tokenizer = load(str(model_path))
            # A short generation proves the weights actually execute.
            output = generate(model, tokenizer, "Hello", max_tokens=10)
        except Exception as e:
            print(f"❌ Verification failed: {e}")
            return False
        print("✅ Model loaded successfully!")
        print(f" Test output: {output[:50]}...")
        return True

    if backend == "transformers":
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer

            AutoTokenizer.from_pretrained(str(model_path))
            # torch_dtype="auto" uses the dtype recorded in the checkpoint
            # config instead of the float32 default, roughly halving memory
            # for half-precision checkpoints.
            AutoModelForCausalLM.from_pretrained(str(model_path), torch_dtype="auto")
        except Exception as e:
            print(f"❌ Verification failed: {e}")
            return False
        print("✅ Model loaded successfully!")
        return True

    print(f"❌ Unsupported backend for verification: {backend!r} "
          "(expected 'mlx' or 'transformers')")
    return False
| |
|
| |
|
def list_models(models=None):
    """Print a table of the available base models.

    Args:
        models: Mapping of model key -> metadata dict with ``size``,
            ``context`` and ``description`` entries. ``None`` (the default)
            uses the module-level ``MODELS`` registry.
    """
    if models is None:
        models = MODELS

    print("\n📋 Available Models:\n")
    print(f"{'Model':<20} {'Size':<8} {'Context':<10} {'Description'}")
    print("-" * 80)

    for key, model in models.items():
        print(f"{key:<20} {model['size']:<8} {model['context']:<10} {model['description']}")
| |
|
| |
|
def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 when every requested step succeeded, 1
        otherwise.
    """
    parser = argparse.ArgumentParser(description="Download and setup base models")
    parser.add_argument("action", choices=["download", "convert", "verify", "list"],
                        help="Action to perform")
    parser.add_argument("-m", "--model", choices=list(MODELS),
                        default="llama-3.1-8b", help="Model to download")
    parser.add_argument("-f", "--format", choices=["mlx", "hf", "gguf", "all"],
                        default="mlx", help="Model format")
    parser.add_argument("-o", "--output", default="models/base",
                        help="Output directory")
    parser.add_argument("-q", "--quant", default="Q4_K_M",
                        help="GGUF quantization level")

    args = parser.parse_args(argv)
    output_dir = Path(args.output)

    if args.action == "list":
        list_models()
        return 0

    # Each step reports its own errors and returns True/False.
    results = []

    if args.action == "download":
        if args.format in ("mlx", "all"):
            results.append(download_mlx_model(args.model, output_dir))
        if args.format in ("hf", "all"):
            results.append(download_hf_model(args.model, output_dir))
        if args.format in ("gguf", "all"):
            results.append(download_gguf_model(args.model, output_dir, args.quant))

    elif args.action == "convert":
        hf_path = output_dir / args.model / "hf"
        mlx_path = output_dir / args.model / "mlx-converted"
        results.append(convert_to_mlx(hf_path, mlx_path))

    elif args.action == "verify":
        # CLI format -> [(download subdirectory, verify_model backend)].
        # The "hf" download is loaded through transformers, so its backend
        # name is "transformers"; passing "hf" through would verify nothing.
        targets = {"mlx": [("mlx", "mlx")], "hf": [("hf", "transformers")]}
        targets["all"] = targets["mlx"] + targets["hf"]
        if args.format not in targets:
            print(f"❌ Verification is not supported for format '{args.format}'")
            return 1
        for subdir, backend in targets[args.format]:
            results.append(verify_model(output_dir / args.model / subdir, backend))

    return 0 if all(results) else 1


if __name__ == "__main__":
    # Propagate failures to the shell so scripts/CI can detect them.
    sys.exit(main())
| |
|