RiverRider
Initial release: SRT-Adapter v8a (peer-review distribution)
aa2d4f1 verified

Examples

load_and_score.py

End-to-end demo. Loads the v8a adapter on top of a frozen Qwen/Qwen2.5-7B and prints all four semiotic readouts for an input passage.

cd examples
pip install -r ../requirements.txt
python load_and_score.py --text "Vaccine mandates are an obvious public health win."

First run downloads Qwen/Qwen2.5-7B (~15 GB) from HuggingFace.

Programmatic use

import json, sys, torch
from pathlib import Path
sys.path.insert(0, "src")

from srt.config import (SRTConfig, MAHConfig, RRMConfig, BENConfig,
                        CommunityConfig, LossConfig)
from srt.adapter import SRTAdapter
from transformers import AutoTokenizer

# Rebuild the SRTConfig from the checkpoint's JSON description.
raw = json.loads(Path("config.json").read_text())

# Keep only the loss keys that LossConfig actually declares, so extra
# bookkeeping fields in config.json cannot break construction.
loss_kwargs = {
    key: value
    for key, value in raw["loss"].items()
    if key in LossConfig.__dataclass_fields__
}

config = SRTConfig(
    backbone_id=raw["backbone_id"],
    backbone_dtype=raw["backbone_dtype"],
    mah_layer_indices=list(raw["mah_layer_indices"]),
    rrm_inject_indices=list(raw["rrm_inject_indices"]),
    community_layer_idx=raw["community_layer_idx"],
    num_mah_layers=raw["num_mah_layers"],
    mah=MAHConfig(**raw["mah"]),
    rrm=RRMConfig(**raw["rrm"]),
    ben=BENConfig(**raw["ben"]),
    community=CommunityConfig(**raw["community"]),
    loss=LossConfig(**loss_kwargs),
)

# Instantiate the adapter on GPU in inference mode.
model = SRTAdapter(config).cuda().eval()

# weights_only=True stops torch.load from unpickling arbitrary objects —
# a plain state_dict of tensors loads fine, and this closes the pickle
# code-execution hole for a checkpoint downloaded from the hub.
state = torch.load("adapter.pt", map_location="cuda", weights_only=True)

# strict=False is intentional here — the frozen backbone's weights are
# absent from adapter.pt — but report what was skipped instead of
# silently ignoring every mismatch.
missing, unexpected = model.load_state_dict(state, strict=False)
if unexpected:
    print(f"warning: {len(unexpected)} unexpected keys in adapter.pt")

# Tokenize a probe sentence with the backbone's own tokenizer and run a
# single no-grad forward pass through the adapter.
tokenizer = AutoTokenizer.from_pretrained(config.backbone_id)
batch = tokenizer("Freedom means different things to different people.",
                  return_tensors="pt").to("cuda")

with torch.no_grad():
    outputs = model(input_ids=batch.input_ids,
                    attention_mask=batch.attention_mask)

# Shapes of the semiotic readouts (B = batch, T = tokens, V = vocab).
print("logits         :", outputs.logits.shape)                   # (1, T, V)
print("community vec  :", outputs.community_output.vector.shape)  # (1, 64)
print("divergences    :", [d.shape for d in outputs.divergences]) # 3× (1, T, 256)
print("r_hat          :", outputs.ben_output.r_hat.shape)         # (1, T)
print("regime logits  :", outputs.ben_output.regime_logits.shape) # (1, T, 2)