API & Inference Usage
This guide covers how to load the MiniLM 1.58-bit base model and attach custom LoRA adapters to it at inference time.
Python Inference (PyTorch)
Because MiniLM uses custom ternary BitLinear layers, it cannot be loaded via the standard transformers AutoModel pipeline. You must use the provided model.py and lora.py scripts.
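For intuition about what "ternary" means here, the following is a minimal sketch of absmean ternary quantization, a scheme commonly used for 1.58-bit weights (three possible values carry log2(3) ≈ 1.58 bits each). It is written in plain Python so that it runs without PyTorch. The name ternary_quantize is hypothetical, and the BitLinear layer in model.py may quantize its weights differently.

```python
import math

# Three possible values per weight carry log2(3) bits, about 1.58
BITS_PER_WEIGHT = math.log2(3)


def ternary_quantize(weights, eps=1e-8):
    """Hypothetical helper: map a 2-D list of floats to codes in {-1, 0, 1}.

    Returns (codes, scale), where scale is the mean absolute weight.
    Multiplying a code by scale gives the approximate original weight.
    """
    flat = [w for row in weights for w in row]
    # Absmean scale; eps guards against division by zero for all-zero weights
    scale = max(sum(abs(w) for w in flat) / len(flat), eps)
    # Divide by the scale, round to the nearest integer, then clip to [-1, 1]
    codes = [[max(-1, min(1, round(w / scale))) for w in row] for row in weights]
    return codes, scale


example = [[0.9, -0.05, -1.2], [0.3, 0.0, 0.6]]
codes, scale = ternary_quantize(example)
print(codes, round(scale, 4))
```

Layers of this kind store parameters and run forward passes that the stock transformers model classes do not know about, which is consistent with the need for the custom model.py.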
1. Loading the Base Model
import torch
from transformers import AutoTokenizer
from model import BitGPT
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
# Initialize the 12-Layer Tied Architecture
model = BitGPT(vocab_size=len(tokenizer), embed_dim=256, num_layers=12, num_heads=4, tie_weights=True).to(device)
# Load the frozen 1.58-bit Base Weights
model.load_state_dict(torch.load("minilm_base.pt", map_location=device))
model.eval()
2. Injecting a "Side-Car" LoRA
To run a specific task (such as smart-home JSON extraction), wrap the model's Linear layers with the custom BitLoraLinear adapter, then load the task-specific LoRA weights.
from lora import inject_lora
# Wrap the model's layers with LoRA adapters
model = inject_lora(model, r=8, lora_alpha=16).to(device)
# Snap on the custom ~1 MB LoRA weights. strict=False tolerates the base-model keys
# that are absent from this checkpoint, so only the parameters it contains are overwritten.
model.load_state_dict(torch.load("lora_smarthome.pt", map_location=device), strict=False)
model.eval()
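To make the "side-car" idea concrete, here is a minimal sketch of the conventional LoRA forward pass, y = W x + (lora_alpha / r) * B (A x), in plain Python. The helper names matvec and lora_forward are hypothetical, and the alpha / r scaling is the usual LoRA convention, assumed here rather than taken from lora.py. BitLoraLinear may be implemented differently.

```python
def matvec(matrix, vector):
    """Multiply a 2-D list (rows x cols) by a 1-D list of length cols."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, r, lora_alpha):
    """Hypothetical sketch: frozen base output plus a scaled low-rank update.

    W: frozen base weights (out x in), e.g. ternary values
    A: trainable down-projection (r x in)
    B: trainable up-projection (out x r)
    """
    base = matvec(W, x)               # frozen path
    update = matvec(B, matvec(A, x))  # low-rank path through r dimensions
    scale = lora_alpha / r
    return [b + scale * u for b, u in zip(base, update)]


W = [[1, 0], [0, -1]]  # ternary base weights
A = [[0.5, 0.5]]       # rank r = 1
x = [2.0, 3.0]

# With B at zero the adapter is a no-op and the base output is returned
print(lora_forward(W, A, [[0.0], [0.0]], x, r=1, lora_alpha=2))
# A non-zero B shifts the output without touching W
print(lora_forward(W, A, [[1.0], [0.0]], x, r=1, lora_alpha=2))
```

Because the base weights W are never modified, swapping tasks only requires loading a different small set of A and B matrices.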
3. Generation Loop
To generate text, format your prompt with the standard ChatML tags, then decode greedily one token at a time:
prompt = "Uh, it's freezing in here, can you turn up the heat in the living room?"
chatml_text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer.encode(chatml_text, return_tensors="pt").to(device)
max_new_tokens = 60
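# Look up the <|im_end|> id from the tokenizer instead of hardcoding it
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
with torch.no_grad():
    for _ in range(max_new_tokens):
        logits = model(input_ids)
        next_token_logits = logits[:, -1, :]
        # Greedy decoding: pick the highest-scoring token
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        # Stop at end-of-sequence or at the ChatML <|im_end|> tag (assumes batch size 1)
        if next_token.item() in (tokenizer.eos_token_id, im_end_id):
            break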
output_text = tokenizer.decode(input_ids[0])
final_output = output_text.split("<|im_start|>assistant\n")[-1].replace("<|im_end|>", "").strip()
print(final_output)
# Output: {"device": "thermostat", "action": "increase_temp", "room": "living_room"}
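A small model can emit malformed or truncated JSON, especially when generation stops at max_new_tokens, so it may be worth validating the text before acting on it. The sketch below uses only the standard library; parse_command is a hypothetical helper and not part of model.py or lora.py.

```python
import json


def parse_command(text):
    """Hypothetical helper: return the JSON object embedded in text, or None."""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end <= start:
        return None  # no braces found, e.g. plain prose or truncated output
    try:
        obj = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None  # braces present but the content is not valid JSON
    # Only accept a JSON object, not a bare list or scalar
    return obj if isinstance(obj, dict) else None


sample = '{"device": "thermostat", "action": "increase_temp", "room": "living_room"}'
command = parse_command(sample)
if command is None:
    print("Model output was not valid JSON")
else:
    print(command["device"], command["action"], command["room"])
```

In the generation example above, this would be called as parse_command(final_output).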