# bitnet_cpu_assistant / model_manager.py
# Uploaded by Vishwas1 (commit ca64889, verified)
import os
import subprocess
import streamlit as st
from huggingface_hub import hf_hub_download
class BitNetManager:
    """Manage the Microsoft BitNet CPU inference engine for a Streamlit app.

    Responsibilities: cloning the BitNet repository next to this module,
    patching known source/build issues, driving the official ``setup_env.py``
    build-and-download script with live log streaming, locating the compiled
    binary and GGUF model weights, and launching inference subprocesses whose
    stdout the UI can stream.
    """

    def __init__(self, repo_url="https://github.com/microsoft/BitNet.git"):
        """Record the repo URL and derive working paths next to this file.

        Args:
            repo_url: Git URL of the BitNet repository to clone.
        """
        self.repo_url = repo_url
        # Anchor all work alongside this module so repeated runs reuse the
        # same checkout and build tree.
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.bitnet_dir = os.path.join(self.base_dir, "BitNet")
        self.build_dir = os.path.join(self.bitnet_dir, "build")

    def patch_source(self):
        """Patch the BitNet source code to fix known compilation errors."""
        target_file = os.path.join(self.bitnet_dir, "src", "ggml-bitnet-mad.cpp")
        if os.path.exists(target_file):
            st.info("Patching source code to fix const-qualifier error...")
            with open(target_file, "r") as f:
                content = f.read()
            # 'y' is const-qualified upstream, so the local pointer must be
            # const too; this fixes the "cannot initialize a variable of type
            # 'int8_t *' with an rvalue of type 'const int8_t *'" error.
            old_str = "int8_t * y_col = y + col * by;"
            new_str = "const int8_t * y_col = y + col * by;"
            if old_str in content:
                patched_content = content.replace(old_str, new_str)
                with open(target_file, "w") as f:
                    f.write(patched_content)
                st.success("Patch applied to ggml-bitnet-mad.cpp!")
            else:
                st.warning("Patch target line not found in ggml-bitnet-mad.cpp.")

        # Patch setup_env.py: replace the huggingface-cli subprocess call with
        # the equivalent Python API call, avoiding PATH/environment issues.
        setup_script = os.path.join(self.bitnet_dir, "setup_env.py")
        if os.path.exists(setup_script):
            st.info("Patching setup_env.py to use Python API instead of CLI...")
            with open(setup_script, "r") as f:
                setup_content = f.read()
            # Exact line we expect to find in the upstream script.
            old_line = 'run_command(["huggingface-cli", "download", hf_url, "--local-dir", model_dir], log_step="download_model")'
            # Replacement using the huggingface_hub Python API.
            new_line = 'from huggingface_hub import snapshot_download; snapshot_download(repo_id=hf_url, local_dir=model_dir)'
            if old_line in setup_content:
                patched_setup = setup_content.replace(old_line, new_line)
                with open(setup_script, "w") as f:
                    f.write(patched_setup)
                st.success("Successfully patched setup_env.py with Python API!")
            elif 'huggingface-cli' in setup_content:
                # Fallback for different quoting or a slightly different call
                # shape: match any run_command([..."huggingface-cli",
                # "download", ...]) invocation.
                import re
                pattern = r'run_command\(\s*\[\s*["\']huggingface-cli["\'],\s*["\']download["\'],[^\]]+\][^)]*\)'
                matches = re.findall(pattern, setup_content)
                if matches:
                    patched_setup = re.sub(pattern, new_line, setup_content)
                    with open(setup_script, "w") as f:
                        f.write(patched_setup)
                    st.success("Successfully patched setup_env.py (via regex)!")
                else:
                    st.warning("Could not find the exact download command in setup_env.py to patch.")

    def setup_engine(self, model_id="1bitLLM/bitnet_b1_58-3B"):
        """Clone and compile utilizing official setup_env.py with log streaming.

        Args:
            model_id: Hugging Face repo id of the 1-bit model to download
                and convert.

        Returns:
            True when the engine binary and model weights are ready,
            False when setup failed.
        """
        model_name = model_id.split("/")[-1]
        model_path = os.path.join(self.bitnet_dir, "models", model_name, "ggml-model-i2_s.gguf")
        binary = self.get_binary_path()
        # Fast path: binary already compiled AND model weights already exist.
        if binary and os.path.exists(binary) and os.path.exists(model_path):
            st.success(f"BitNet engine and model ({model_name}) are ready!")
            return True
        if binary and os.path.exists(binary):
            st.info(f"Engine binary found, but model weights for {model_name} are missing. Starting setup...")
        if not os.path.exists(self.bitnet_dir):
            st.info("Cloning BitNet repository...")
            # BUGFIX: clone explicitly into base_dir so the checkout lands at
            # self.bitnet_dir regardless of the process's working directory.
            subprocess.run(["git", "clone", "--recursive", self.repo_url], cwd=self.base_dir, check=True)
        self.patch_source()
        st.info("Running official BitNet setup (setup_env.py)...")
        try:
            # -u for unbuffered output to see logs in real-time.
            # BUGFIX: pass the caller's model_id instead of a hard-coded repo
            # that silently ignored the parameter.
            cmd = ["python", "-u", "setup_env.py", "--hf-repo", model_id, "--use-pretuned"]
            # Stream combined stdout/stderr to Streamlit in real-time.
            process = subprocess.Popen(cmd, cwd=self.bitnet_dir, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
            log_container = st.empty()
            logs = []
            for line in process.stdout:
                logs.append(line)
                # Keep only the last 15 lines or so for UI clarity.
                log_container.code("".join(logs[-15:]))
            process.wait()
            if process.returncode != 0:
                st.error(f"Setup failed (Exit {process.returncode})")
                # Surface the most specific log file setup_env.py produced.
                logs_dir = os.path.join(self.bitnet_dir, "logs")
                comp_log = os.path.join(logs_dir, "compile.log")
                down_log = os.path.join(logs_dir, "download_model.log")
                if os.path.exists(down_log):
                    st.info("Download Log (logs/download_model.log):")
                    with open(down_log, "r") as f:
                        st.code(f.read()[-3000:])
                elif os.path.exists(comp_log):
                    st.info("Compilation Log (logs/compile.log):")
                    with open(comp_log, "r") as f:
                        st.code(f.read()[-3000:])
                else:
                    st.info("Detailed Output:")
                    st.code("".join(logs)[-2000:])
                return False
            st.success("Official Setup Completed Successfully!")
            return True
        except Exception as e:
            # BUGFIX: corrected mixed-language message ("durante" -> "during").
            st.error(f"Execution error during setup: {e}")
            return False

    def get_binary_path(self):
        """Locate the bitnet binary based on platform/build structure.

        Returns:
            The first existing candidate path, or None when nothing is built.
        """
        possible_paths = [
            os.path.join(self.bitnet_dir, "build", "bin", "llama-cli"),  # Standard location
            os.path.join(self.bitnet_dir, "build", "llama-cli"),  # Alternate location
            os.path.join(self.bitnet_dir, "build", "bitnet"),  # Legacy/Custom
            os.path.join(self.bitnet_dir, "build", "bin", "bitnet"),
            os.path.join(self.bitnet_dir, "build", "Release", "bitnet.exe"),  # Windows
            os.path.join(self.bitnet_dir, "build", "bin", "Release", "llama-cli.exe"),
            os.path.join(self.bitnet_dir, "run_inference.py"),  # Script fallback
        ]
        for p in possible_paths:
            if os.path.exists(p):
                return p
        return None

    def download_model(self, model_id="1bitLLM/bitnet_b1_58-3B", filename="ggml-model-i2_s.gguf"):
        """Locate the model weights. These are generated locally by setup_env.py.

        Args:
            model_id: Hugging Face repo id; only its last path segment is used
                to derive the local models/<model_name>/ directory.
            filename: GGUF file name produced by the setup/conversion step.

        Returns:
            The local path to the GGUF file, or None if it does not exist yet.
        """
        # setup_env.py downloads weights to models/<model_name>/,
        # e.g. models/bitnet_b1_58-3B/.
        model_name = model_id.split("/")[-1]
        local_model_path = os.path.join(self.bitnet_dir, "models", model_name, filename)
        if os.path.exists(local_model_path):
            st.success(f"Found local model: {model_name}")
            return local_model_path
        st.error(f"Model file not found at {local_model_path}")
        st.info("The GGUF model must be generated by the 'Initialize Engine' process. Please run it again to download and convert the weights.")
        return None

    def run_inference(self, prompt, model_path):
        """Execute the bitnet binary with the provided prompt.

        Args:
            prompt: Text prompt passed via -p.
            model_path: Path to the GGUF weights passed via -m.

        Returns:
            A subprocess.Popen handle whose stdout the caller can stream,
            or None when the binary is missing or launching failed.
        """
        binary = self.get_binary_path()
        if not binary:
            st.error("Inference binary not found. Please re-run Initialization.")
            return None
        # Build the command. The bitnet binary usually takes -m and -p;
        # the Python fallback script takes the same flags via the interpreter.
        if binary.endswith(".py"):
            cmd = ["python", binary, "-m", model_path, "-p", prompt, "-n", "128"]
        else:
            cmd = [binary, "-m", model_path, "-p", prompt, "-n", "128"]
        try:
            # Return a Popen object so the app can stream the response.
            # CRITICAL: cwd=self.bitnet_dir so run_inference.py can find build/.
            process = subprocess.Popen(
                cmd,
                cwd=self.bitnet_dir,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1
            )
            return process
        except Exception as e:
            st.error(f"Inference execution failed: {e}")
            return None