reFlow / model_loader.py
"""Singleton model loader — downloads reflow-1-big from HF Hub once."""
import os
import torch
import tiktoken
# Try to import from models/ directory first, then from root
try:
from models.reflow import GPT, GPTConfig
except ImportError:
from reflow import GPT, GPTConfig

_model = None
_enc = None
_device = None
_W_v2s = None
_signal_basis = None
# Local paths to check before downloading (relative to this file's directory)
_LOCAL_CKPT_CANDIDATES = [
"model/ckpt.pt",
]

def _find_local_ckpt():
    """Search for a local checkpoint file relative to this file's directory."""
base_dir = os.path.dirname(os.path.abspath(__file__))
for rel in _LOCAL_CKPT_CANDIDATES:
path = os.path.join(base_dir, rel)
if os.path.isfile(path):
return path
return None

def get_model():
    """Return (model, enc, device). The first call loads the checkpoint, downloading it if necessary."""
global _model, _enc, _device, _W_v2s, _signal_basis
if _model is not None:
return _model, _enc, _device
_device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[model_loader] device = {_device}")
# 1. Try local checkpoint first
ckpt_path = _find_local_ckpt()
if ckpt_path:
print(f"[model_loader] Found local checkpoint: {ckpt_path}")
else:
# 2. Fall back to download from HF Hub or ModelScope
print("[model_loader] No local checkpoint found, downloading from cloud ...")
try:
from huggingface_hub import hf_hub_download
print("[model_loader] Trying Hugging Face Hub ...")
ckpt_path = hf_hub_download(
repo_id="reuAC/reFlow",
filename="out/reflow-1-big/ckpt.pt",
)
print(f"[model_loader] Downloaded from Hugging Face: {ckpt_path}")
except Exception as e:
print(f"[model_loader] Hugging Face download failed: {e}")
print("[model_loader] Trying ModelScope ...")
from modelscope.hub.file_download import model_file_download
ckpt_path = model_file_download(
model_id="recuAC/reFlow",
file_path="out/reflow-1-big/ckpt.pt",
)
print(f"[model_loader] Downloaded from ModelScope: {ckpt_path}")
checkpoint = torch.load(ckpt_path, map_location=_device, weights_only=False)
# Build model from saved config
model_args = checkpoint["model_args"]
config = GPTConfig(**model_args)
_model = GPT(config)
# Strip _orig_mod. prefix (torch.compile artifact)
state_dict = checkpoint["model"]
for k in list(state_dict.keys()):
if k.startswith("_orig_mod."):
state_dict[k[len("_orig_mod."):]] = state_dict.pop(k)
_model.load_state_dict(state_dict)
_model.eval().to(_device)
# Cache frequently-used tensors
_W_v2s = _model.transformer.wte.vocab_to_signals.weight.data
_signal_basis = _model.transformer.wte.signal_basis.data
_enc = tiktoken.get_encoding("gpt2")
print("[model_loader] Model ready.")
return _model, _enc, _device

def get_cached_tensors():
    """Return (W_v2s, signal_basis), loading the model via get_model() if needed."""
if _W_v2s is None:
get_model()
return _W_v2s, _signal_basis
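

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the module's API: it
    # assumes the checkpoint is reachable locally or via one of the hubs above.
    model, enc, device = get_model()
    tokens = enc.encode("Hello, reFlow!")
    print(f"[smoke test] {len(tokens)} tokens on {device}: {tokens}")
    W_v2s, signal_basis = get_cached_tensors()
    print(f"[smoke test] W_v2s {tuple(W_v2s.shape)}, signal_basis {tuple(signal_basis.shape)}")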