"""MALM Inference Script - Run directly from Hugging Face. |
|
|
|
|
|
Usage: |
|
|
# Install dependencies |
|
|
pip install mlx huggingface_hub |
|
|
|
|
|
# Download and run |
|
|
huggingface-cli download codelion/malm-165m --local-dir ./malm-165m |
|
|
python malm-165m/inference.py --query "function that sorts a list" |
|
|
""" |

import argparse
import json
import re
from pathlib import Path
from typing import Dict, List, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np


class MALM(nn.Module):
    """Memory-Augmented Language Model."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_heads: int = 12,
        n_layers: int = 12,
        n_query_layers: int = 4,
        max_seq_len: int = 128,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.n_query_layers = n_query_layers
        self.max_seq_len = max_seq_len

        # Shared token embedding plus learned positional embedding.
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        self.embed_dropout = nn.Dropout(dropout)

        # Query encoder: a shallow Transformer stack pooled to one vector.
        self.query_layers = [
            nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
            for _ in range(n_query_layers)
        ]
        self.query_ln = nn.LayerNorm(d_model)
        self.query_proj = nn.Linear(d_model, d_model)

        # Value encoder: same depth as the query encoder.
        self.value_layers = [
            nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
            for _ in range(n_query_layers)
        ]
        self.value_ln = nn.LayerNorm(d_model)
        self.value_proj = nn.Linear(d_model, d_model)

        # Decoder stack and output head: defined so the checkpoint loads,
        # but not used by the search path in this script.
        self.decoder_layers = [
            nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
            for _ in range(n_layers)
        ]
        self.decoder_ln = nn.LayerNorm(d_model)

        self.output = nn.Linear(d_model, vocab_size)

        # Learnable retrieval temperature, stored as a log for positivity.
        self.log_temp = mx.array([0.0])

    def encode_query(self, query_ids: mx.array) -> mx.array:
        """Encode a query to a single embedding."""
        # Truncate so positions stay within the learned positional table;
        # otherwise the addition below fails for inputs longer than
        # max_seq_len.
        query_ids = query_ids[:, : self.max_seq_len]
        B, L = query_ids.shape

        h = self.embed(query_ids)
        pos = mx.arange(L)
        h = h + self.pos_embed(pos)
        h = self.embed_dropout(h)

        for layer in self.query_layers:
            h = layer(h, None)

        h = self.query_ln(h)

        # Mean-pool over non-padding tokens (PAD id is 0).
        mask = (query_ids != 0).astype(mx.float32)[:, :, None]
        h = h * mask
        query_emb = mx.sum(h, axis=1) / (mx.sum(mask, axis=1) + 1e-8)

        return self.query_proj(query_emb)
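
    # Illustration of the pooling above (comments only, toy numbers): with
    # PAD id 0, a row of token ids [5, 9, 0, 0] keeps positions 0 and 1, so
    # the pooled embedding is (h[0] + h[1]) / 2 rather than a mean over the
    # padded length of 4.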

    def encode_value(self, value_ids: mx.array) -> mx.array:
        """Encode a value to a single embedding."""
        # Same truncation as encode_query, for the same reason.
        value_ids = value_ids[:, : self.max_seq_len]
        B, L = value_ids.shape

        h = self.embed(value_ids)
        pos = mx.arange(L)
        h = h + self.pos_embed(pos)

        for layer in self.value_layers:
            h = layer(h, None)

        h = self.value_ln(h)

        mask = (value_ids != 0).astype(mx.float32)[:, :, None]
        h = h * mask
        val_emb = mx.sum(h, axis=1) / (mx.sum(mask, axis=1) + 1e-8)

        return self.value_proj(val_emb)

    def retrieve(
        self,
        query_emb: mx.array,
        key_emb: mx.array,
        val_emb: mx.array,
    ) -> Tuple[mx.array, mx.array, mx.array]:
        """Retrieve from memory via scaled dot-product attention."""
        scale = self.d_model ** -0.5
        # Temperature is learned in log space; the +0.1 keeps it bounded
        # away from zero.
        temp = mx.exp(self.log_temp) + 0.1

        scores = (query_emb @ key_emb.T) * scale / temp
        attn = mx.softmax(scores, axis=-1)
        retrieved = attn @ val_emb

        return retrieved, attn, scores
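
    # Worked illustration of retrieve() (comments only; toy numbers, not the
    # trained model). With d_model = 4, temperature 1, and a 2-entry memory:
    #
    #   q = [[1, 0, 0, 0]]                  # query embedding, shape (1, 4)
    #   K = [[1, 0, 0, 0], [0, 1, 0, 0]]    # one key per memory entry
    #   V = [[2, 0, 0, 0], [0, 2, 0, 0]]    # one value per memory entry
    #   scores = (q @ K.T) * 4 ** -0.5      # -> [[0.5, 0.0]]
    #   attn   = softmax(scores, axis=-1)   # -> roughly [[0.62, 0.38]]
    #   out    = attn @ V                   # -> roughly [[1.24, 0.76, 0, 0]]
    #
    # search_functions() below ranks functions by the raw scores row rather
    # than the softmax weights.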


class Tokenizer:
    """Simple tokenizer for MALM."""

    def __init__(self, tokenizer_dict: Dict):
        self.token_to_id = tokenizer_dict.get("token_to_id", {})
        self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
        self.special = {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3}

    def encode(self, text: str) -> List[int]:
        """Tokenize text into identifiers, numbers, and single characters."""
        tokens = re.findall(r"[a-zA-Z_][a-zA-Z0-9_]*|[0-9]+|[^\s]", text.lower())
        return [self.token_to_id.get(t, self.special.get("<UNK>", 1)) for t in tokens]

    def decode(self, ids: List[int]) -> str:
        """Decode token IDs to text."""
        tokens = [self.id_to_token.get(i, "<UNK>") for i in ids]
        return " ".join(tokens)


def load_model(model_dir: Path):
    """Load the MALM model, tokenizer, and function index from a directory."""
    import mlx.utils as mlx_utils

    with open(model_dir / "config.json") as f:
        config = json.load(f)

    model = MALM(
        vocab_size=config["vocab_size"],
        d_model=config["d_model"],
        n_heads=config["n_heads"],
        n_layers=config["n_layers"],
        n_query_layers=config["n_query_layers"],
        max_seq_len=config["max_seq_len"],
    )

    # Load the flat npz weights and rebuild MLX's nested parameter tree.
    weights = dict(np.load(model_dir / "model.npz"))
    weights = {k: mx.array(v) for k, v in weights.items()}
    params = mlx_utils.tree_unflatten(list(weights.items()))
    model.update(params)
    mx.eval(model.parameters())

    with open(model_dir / "tokenizer.json") as f:
        tokenizer_dict = json.load(f)
    tokenizer = Tokenizer(tokenizer_dict)

    with open(model_dir / "functions.json") as f:
        functions = json.load(f)

    return model, tokenizer, functions, config
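
# Note on the weight loading above: mlx.utils.tree_unflatten turns the flat
# dotted names stored in the npz file into the nested structure that
# Module.update expects, e.g.
#   [("query_proj.weight", W)]  ->  {"query_proj": {"weight": W}}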


def search_functions(
    model: MALM,
    tokenizer: Tokenizer,
    functions: List[Dict],
    query: str,
    top_k: int = 5,
) -> List[Tuple[str, str, str, float]]:
    """Search for functions matching a query.

    Uses the function name as the memory key and signature+docstring as the
    value. Returns (name, signature, docstring, score) tuples.
    """
    query_ids = tokenizer.encode(query)
    if not query_ids:
        query_ids = [1]  # Fall back to a single <UNK> for empty queries.
    query_ids = mx.array([query_ids])

    # Build the key/value token batches for every function in the index.
    key_tokens = []
    value_tokens = []
    max_val_len = 64

    for func in functions:
        name = func["name"]

        sig = func.get("signature", name)
        doc = func.get("docstring", "")
        value_text = f"{sig} {doc}"

        # Each key is a single token: the function name (or <UNK> if unseen).
        key_id = tokenizer.token_to_id.get(name.lower(), 1)
        key_tokens.append(key_id)

        # Each value is the signature+docstring, truncated and zero-padded.
        val_ids = tokenizer.encode(value_text)[:max_val_len]
        val_ids = val_ids + [0] * (max_val_len - len(val_ids))
        value_tokens.append(val_ids)

    key_tokens = mx.array(key_tokens)
    value_tokens = mx.array(value_tokens)

    # Keys are looked up directly in the embedding table; values go through
    # the full value encoder.
    key_emb = model.embed(key_tokens)
    val_emb = model.encode_value(value_tokens)

    query_emb = model.encode_query(query_ids)
    _, _, scores = model.retrieve(query_emb, key_emb, val_emb)
    mx.eval(scores)

    # Rank all functions by raw (pre-softmax) retrieval score.
    scores_np = np.array(scores[0])
    top_indices = np.argsort(scores_np)[::-1][:top_k]

    results = []
    for idx in top_indices:
        func = functions[idx]
        score = float(scores_np[idx])
        sig = func.get("signature", func["name"])
        doc = func.get("docstring", "")
        results.append((func["name"], sig, doc, score))

    return results
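
# Programmatic usage (a sketch; assumes the model files were downloaded to
# ./malm-165m as in the module docstring):
#
#   model, tokenizer, functions, config = load_model(Path("./malm-165m"))
#   for name, sig, doc, score in search_functions(
#       model, tokenizer, functions, "parse a json file", top_k=3
#   ):
#       print(f"{score:8.3f}  {sig}")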


def main():
    parser = argparse.ArgumentParser(description="MALM Inference - Semantic Code Search")
    parser.add_argument("--query", type=str, required=True, help="Natural language query")
    parser.add_argument("--top-k", type=int, default=5, help="Number of results")
    parser.add_argument("--model-dir", type=str, default=None, help="Model directory")
    args = parser.parse_args()

    # Default to the directory this script lives in (the downloaded repo).
    if args.model_dir:
        model_dir = Path(args.model_dir)
    else:
        model_dir = Path(__file__).parent

    print(f"Loading model from {model_dir}...")
    model, tokenizer, functions, config = load_model(model_dir)
    print(f"Loaded {len(functions)} functions, {config['num_parameters']:,} parameters")

    print(f"\nQuery: {args.query}")
    print("-" * 60)

    results = search_functions(model, tokenizer, functions, args.query, args.top_k)

    for i, (name, signature, docstring, score) in enumerate(results, 1):
        print(f"\n{i}. {name} (score: {score:.4f})")
        print(f"   Signature: {signature}")
        if docstring:
            print(f"   Docstring: {docstring[:100]}{'...' if len(docstring) > 100 else ''}")


if __name__ == "__main__":
    main()