Spaces:
Sleeping
Sleeping
| """ | |
| CPU-safe model revision selector. | |
| This script finds a model revision that doesn't hard-require flash_attn | |
| for CPU-only environments. | |
| """ | |
| import os | |
| import re | |
| import sys | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional, List | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError | |
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
# Hub repository to inspect; override with the HF_MODEL_ID environment variable.
MODEL_ID = os.getenv("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
# Remote-code modeling file inspected at each revision. The file name is
# model-specific, so it can be overridden (HF_MODELING_FILE) together with
# HF_MODEL_ID; otherwise every revision of a different model would fail to
# download and be reported as unsafe.
TARGET_FILE = os.getenv("HF_MODELING_FILE", "modeling_phimoe.py")
# Where the selected revision is written (relative to the working directory).
ENV_FILE = Path(".env")
# Default upper bound on the number of commits inspected.
MAX_COMMITS_TO_CHECK = 50
class RevisionSelector:
    """Selects CPU-safe model revisions.

    A revision is CPU-safe when its remote modeling file (``TARGET_FILE``)
    contains no unconditional ``flash_attn`` import. Candidate commits come
    from ``HfApi.list_repo_commits``; each modeling file is downloaded and
    parsed statically (never executed).
    """

    def __init__(self, model_id: str = MODEL_ID):
        """Bind the selector to ``model_id`` and create a Hub API client."""
        self.model_id = model_id
        self.api = HfApi()

    @staticmethod
    def _has_hard_flash_attn_import(code: str) -> bool:
        """Return True if ``code`` contains an unconditional flash_attn import.

        The source is parsed with ``ast`` instead of being scanned with
        regexes, so:

        * only real import statements count. Comments, docstrings and names
          that merely contain the substring (e.g. ``is_flash_attn_2_available``)
          are ignored.
        * imports nested under an ``if`` (such as
          ``if is_flash_attn_2_available():``) or inside a ``try`` block are
          treated as optional, not hard.
        * unguarded imports anywhere else, including inside function or class
          bodies, still count.

        Raises:
            SyntaxError: if ``code`` cannot be parsed as Python.
        """
        import ast

        # ast.TryStar (``try``/``except*``) only exists on Python 3.11+.
        guard_types = tuple(
            getattr(ast, name) for name in ("If", "Try", "TryStar") if hasattr(ast, name)
        )

        pending = [ast.parse(code)]
        while pending:
            node = pending.pop()
            if isinstance(node, guard_types):
                # Conditional / exception-guarded imports are optional by design.
                continue
            if isinstance(node, ast.Import):
                modules = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
                # level > 0 is a relative import of a sibling file, not the package.
                modules = [node.module]
            else:
                modules = []
            # Compare only the top-level package name. "flash_attn",
            # "flash_attn.bert_padding" and the compiled "flash_attn_2_cuda"
            # extension match; "transformers.utils" does not.
            if any(module.split(".")[0].startswith("flash_attn") for module in modules):
                return True
            pending.extend(ast.iter_child_nodes(node))
        return False

    def is_cpu_safe_revision(self, revision: str) -> bool:
        """Check if a revision is safe for CPU use (no hard flash_attn import).

        Returns False both when a hard import is found and when the check
        itself fails (download error, missing file, unparsable source). An
        unverifiable revision is therefore never selected.
        """
        try:
            # Download the modeling file as it exists at this revision.
            file_path = hf_hub_download(
                repo_id=self.model_id,
                filename=TARGET_FILE,
                revision=revision,
                repo_type="model",
                cache_dir=".cache",
            )
            with open(file_path, "r", encoding="utf-8") as f:
                code = f.read()
            if self._has_hard_flash_attn_import(code):
                logger.debug(f"Revision {revision} has hard flash_attn import")
                return False
            logger.debug(f"Revision {revision} appears CPU-safe")
            return True
        except Exception as e:
            # Best-effort per revision: log and let the caller try the next commit.
            logger.warning(f"Could not check revision {revision}: {e}")
            return False

    def get_recent_commits(self, max_commits: int = MAX_COMMITS_TO_CHECK) -> List[str]:
        """Return up to ``max_commits`` commit SHAs of the model repository.

        SHAs are returned in the order given by ``HfApi.list_repo_commits``.
        A negative ``max_commits`` is treated as 0. On any API error the
        failure is logged and an empty list is returned.
        """
        try:
            commits = list(self.api.list_repo_commits(
                repo_id=self.model_id,
                repo_type="model"
            ))
            commit_shas = [c.commit_id for c in commits[:max(max_commits, 0)]]
            logger.info(f"Found {len(commit_shas)} recent commits to check")
            return commit_shas
        except Exception as e:
            logger.error(f"Failed to get commits: {e}")
            return []

    def find_cpu_safe_revision(self) -> Optional[str]:
        """Return the first CPU-safe commit SHA from ``get_recent_commits``, or None."""
        logger.info(f"Searching for CPU-safe revision of {self.model_id}")
        commits = self.get_recent_commits()
        if not commits:
            logger.error("No commits found")
            return None

        total = len(commits)
        for position, commit_sha in enumerate(commits, start=1):
            logger.info(f"Checking commit {position}/{total}: {commit_sha[:8]}...")
            if self.is_cpu_safe_revision(commit_sha):
                logger.info(f"✅ Found CPU-safe revision: {commit_sha}")
                return commit_sha

        logger.error("❌ No CPU-safe revision found in recent commits")
        return None

    @staticmethod
    def _with_revision(env_content: str, revision: str) -> str:
        """Return ``env_content`` with every HF_REVISION assignment replaced by one new line.

        Assignments are recognised even with leading whitespace, an ``export``
        prefix or spaces around ``=``, so no duplicate keys are left behind.
        All other lines (including comments) are kept in order. The result
        always ends with a newline, and empty input yields no blank line.
        """
        kept = [
            line
            for line in env_content.splitlines()
            if not re.match(r"\s*(?:export\s+)?HF_REVISION\s*=", line)
        ]
        kept.append(f"HF_REVISION={revision}")
        return "\n".join(kept) + "\n"

    def save_revision_to_env(self, revision: str) -> None:
        """Save the selected revision to ``ENV_FILE`` as ``HF_REVISION=<revision>``.

        Any previous HF_REVISION assignment is replaced; other lines are kept.

        Raises:
            Exception: any read/write error is logged and re-raised.
        """
        try:
            env_content = ""
            if ENV_FILE.exists():
                env_content = ENV_FILE.read_text(encoding="utf-8")
            ENV_FILE.write_text(self._with_revision(env_content, revision), encoding="utf-8")
            logger.info(f"✅ Saved revision {revision} to {ENV_FILE}")
        except Exception as e:
            logger.error(f"Failed to save revision to .env: {e}")
            raise
def main() -> int:
    """Select a CPU-safe revision and save it to the .env file.

    Returns:
        The process exit code. It is 0 in three cases: a GPU is available,
        HF_REVISION is already set in the environment, or a revision was
        found and saved. It is 1 when no CPU-safe revision was found or
        selection raised an error.
    """
    # Imported locally rather than at module top.
    # NOTE(review): presumably so the module can be imported without torch; confirm.
    import torch

    # With a GPU available, the CPU-safe search is unnecessary.
    if torch.cuda.is_available():
        logger.info("GPU detected - no need to select CPU-safe revision")
        return 0

    # Respect an explicit pin. Only os.environ is consulted here, not the
    # contents of the .env file.
    existing_revision = os.getenv("HF_REVISION")
    if existing_revision:
        logger.info(f"HF_REVISION already set to: {existing_revision}")
        return 0

    logger.info("CPU-only environment detected - selecting CPU-safe revision")
    try:
        selector = RevisionSelector()
        revision = selector.find_cpu_safe_revision()
        if not revision:
            logger.error("❌ Could not find CPU-safe revision")
            logger.error("Consider using a different model or enabling GPU")
            return 1
        selector.save_revision_to_env(revision)
        logger.info(f"🎉 Successfully selected CPU-safe revision: {revision}")
        return 0
    except Exception as e:
        # Top-level boundary: keep the traceback in the log and report failure
        # through the exit code instead of crashing.
        logger.exception(f"❌ Error selecting revision: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())