#!/usr/bin/env python3
"""MALM Inference Script - Run directly from Hugging Face.
Usage:
# Install dependencies
pip install mlx huggingface_hub
# Download and run
huggingface-cli download codelion/malm-165m --local-dir ./malm-165m
python malm-165m/inference.py --query "function that sorts a list"
"""
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import json
import argparse
from pathlib import Path
from typing import List, Dict, Tuple
import re
class MALM(nn.Module):
"""Memory-Augmented Language Model."""
def __init__(
self,
vocab_size: int,
d_model: int = 768,
n_heads: int = 12,
n_layers: int = 12,
n_query_layers: int = 4,
max_seq_len: int = 128,
dropout: float = 0.0,
):
super().__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.n_query_layers = n_query_layers
self.max_seq_len = max_seq_len
# Embeddings
self.embed = nn.Embedding(vocab_size, d_model)
self.pos_embed = nn.Embedding(max_seq_len, d_model)
self.embed_dropout = nn.Dropout(dropout)
# Query encoder
self.query_layers = [
nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
for _ in range(n_query_layers)
]
self.query_ln = nn.LayerNorm(d_model)
self.query_proj = nn.Linear(d_model, d_model)
# Value encoder
self.value_layers = [
nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
for _ in range(n_query_layers)
]
self.value_ln = nn.LayerNorm(d_model)
self.value_proj = nn.Linear(d_model, d_model)
# Decoder layers
self.decoder_layers = [
nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
for _ in range(n_layers)
]
self.decoder_ln = nn.LayerNorm(d_model)
# Output
self.output = nn.Linear(d_model, vocab_size)
# Temperature for retrieval
self.log_temp = mx.array([0.0])
def encode_query(self, query_ids: mx.array) -> mx.array:
"""Encode query to single embedding."""
B, L = query_ids.shape
h = self.embed(query_ids)
pos = mx.arange(min(L, self.max_seq_len))
h = h + self.pos_embed(pos)
h = self.embed_dropout(h)
for layer in self.query_layers:
h = layer(h, None)
h = self.query_ln(h)
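        # Mean-pool over non-PAD positions (PAD id is 0) to get one vector per query.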
mask = (query_ids != 0).astype(mx.float32)[:, :, None]
h = h * mask
query_emb = mx.sum(h, axis=1) / (mx.sum(mask, axis=1) + 1e-8)
return self.query_proj(query_emb)
def encode_value(self, value_ids: mx.array) -> mx.array:
"""Encode value to single embedding."""
B, L = value_ids.shape
h = self.embed(value_ids)
pos = mx.arange(min(L, self.max_seq_len))
h = h + self.pos_embed(pos)
for layer in self.value_layers:
h = layer(h, None)
h = self.value_ln(h)
mask = (value_ids != 0).astype(mx.float32)[:, :, None]
h = h * mask
val_emb = mx.sum(h, axis=1) / (mx.sum(mask, axis=1) + 1e-8)
return self.value_proj(val_emb)
def retrieve(
self,
query_emb: mx.array,
key_emb: mx.array,
val_emb: mx.array,
) -> Tuple[mx.array, mx.array, mx.array]:
"""Retrieve from memory."""
scale = self.d_model ** -0.5
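        # Learned retrieval temperature: exp keeps it positive, +0.1 keeps it away from zero.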
temp = mx.exp(self.log_temp) + 0.1
scores = (query_emb @ key_emb.T) * scale / temp
attn = mx.softmax(scores, axis=-1)
retrieved = attn @ val_emb
return retrieved, attn, scores
class Tokenizer:
"""Simple tokenizer for MALM."""
def __init__(self, tokenizer_dict: Dict):
self.token_to_id = tokenizer_dict.get("token_to_id", {})
self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
self.special = {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3}
def encode(self, text: str) -> List[int]:
"""Tokenize text."""
tokens = re.findall(r"[a-zA-Z_][a-zA-Z0-9_]*|[0-9]+|[^\s]", text.lower())
return [self.token_to_id.get(t, self.special.get("<UNK>", 1)) for t in tokens]
def decode(self, ids: List[int]) -> str:
"""Decode token IDs to text."""
tokens = [self.id_to_token.get(i, "<UNK>") for i in ids]
return " ".join(tokens)
def load_model(model_dir: Path):
"""Load MALM model from directory."""
import mlx.utils as mlx_utils
# Load config
with open(model_dir / "config.json") as f:
config = json.load(f)
# Create model
model = MALM(
vocab_size=config["vocab_size"],
d_model=config["d_model"],
n_heads=config["n_heads"],
n_layers=config["n_layers"],
n_query_layers=config["n_query_layers"],
max_seq_len=config["max_seq_len"],
)
# Load weights and convert to mlx arrays
weights = dict(np.load(model_dir / "model.npz"))
weights = {k: mx.array(v) for k, v in weights.items()}
# Unflatten and load
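    # (tree_unflatten turns the flat dotted keys from the npz file back into nested parameter dicts)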
params = mlx_utils.tree_unflatten(list(weights.items()))
model.update(params)
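    # Force evaluation so the lazily-created parameter arrays are materialized now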
mx.eval(model.parameters())
# Load tokenizer
with open(model_dir / "tokenizer.json") as f:
tokenizer_dict = json.load(f)
tokenizer = Tokenizer(tokenizer_dict)
# Load functions
with open(model_dir / "functions.json") as f:
functions = json.load(f)
return model, tokenizer, functions, config
def search_functions(
model: MALM,
tokenizer: Tokenizer,
functions: List[Dict],
query: str,
top_k: int = 5,
) -> List[Tuple[str, str, str, float]]:
    """Search for functions matching a query.

    Uses the function name as key and signature+docstring as value for retrieval.
    Returns (name, signature, docstring, score) tuples ranked by score.
    """
    # Encode query (truncate to the model's context length so positional embeddings line up)
    query_ids = tokenizer.encode(query)[: model.max_seq_len]
if not query_ids:
query_ids = [1] # <UNK>
query_ids = mx.array([query_ids])
# Encode all function keys and values
key_tokens = []
value_tokens = []
max_val_len = 64
for func in functions:
name = func["name"]
# Use signature + docstring as the "value" to search over
sig = func.get("signature", name)
doc = func.get("docstring", "")
value_text = f"{sig} {doc}"
key_id = tokenizer.token_to_id.get(name.lower(), 1)
key_tokens.append(key_id)
val_ids = tokenizer.encode(value_text)[:max_val_len]
val_ids = val_ids + [0] * (max_val_len - len(val_ids))
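        # Right-pad with <PAD> (id 0); encode_value masks these positions out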
value_tokens.append(val_ids)
key_tokens = mx.array(key_tokens)
value_tokens = mx.array(value_tokens)
    # Encode memory: keys are the raw token embeddings of the function names,
    # values are the encoded signature+docstring sequences
key_emb = model.embed(key_tokens)
val_emb = model.encode_value(value_tokens)
# Get query embedding and compute similarity
query_emb = model.encode_query(query_ids)
_, attn, scores = model.retrieve(query_emb, key_emb, val_emb)
mx.eval(scores)
# Get top-k
scores_np = np.array(scores[0])
top_indices = np.argsort(scores_np)[::-1][:top_k]
results = []
for idx in top_indices:
func = functions[idx]
score = float(scores_np[idx])
sig = func.get("signature", func["name"])
doc = func.get("docstring", "")
results.append((func["name"], sig, doc, score))
return results
def main():
parser = argparse.ArgumentParser(description="MALM Inference - Semantic Code Search")
parser.add_argument("--query", type=str, required=True, help="Natural language query")
parser.add_argument("--top-k", type=int, default=5, help="Number of results")
parser.add_argument("--model-dir", type=str, default=None, help="Model directory")
args = parser.parse_args()
# Determine model directory
if args.model_dir:
model_dir = Path(args.model_dir)
else:
model_dir = Path(__file__).parent
print(f"Loading model from {model_dir}...")
model, tokenizer, functions, config = load_model(model_dir)
print(f"Loaded {len(functions)} functions, {config['num_parameters']:,} parameters")
# Search
print(f"\nQuery: {args.query}")
print("-" * 60)
results = search_functions(model, tokenizer, functions, args.query, args.top_k)
for i, (name, signature, docstring, score) in enumerate(results, 1):
print(f"\n{i}. {name} (score: {score:.4f})")
print(f" Signature: {signature}")
if docstring:
print(f" Docstring: {docstring[:100]}{'...' if len(docstring) > 100 else ''}")
if __name__ == "__main__":
main()
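

# Programmatic use (a minimal sketch; assumes the model files sit in ./malm-165m
# and that "read a csv file" is just an example query):
#
#   model, tokenizer, functions, config = load_model(Path("./malm-165m"))
#   results = search_functions(model, tokenizer, functions, "read a csv file", top_k=3)
#   for name, signature, docstring, score in results:
#       print(f"{name}: {score:.3f}")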