"""
Memory Routing RL Environment

This implements the MemoryRoutingEnv for Stage 2 (RL Optimization) per PRD Section 8.

Per Tinker docs (rl/rl-envs.mdx):
- Env operates on tokens, not strings
- Implement initial_observation() and step()
- EnvGroupBuilder creates groups of environments
- RLDataset provides batches of EnvGroupBuilders

Per PRD Section 4 (Reward Computation):
- R_F1: Token-level F1 between predicted and gold categories
- R_temp: Persistence alignment (+1.0 exact, +0.5 adjacent, 0.0 otherwise)
- R_parity: Company/user scope alignment
- R_eff: Storage efficiency (penalize >3 categories)
- R_total = 0.6*R_F1 + 0.2*R_temp + 0.1*R_parity + 0.1*R_eff

Per PRD Section 4 (Environment Design):
- Single-step bandit: initial_observation() returns the conversation, step() terminates
- EnvGroupBuilder clones each conversation across group_size rollouts
"""

import json
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Set, Tuple

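# The valid routing labels: 12 memory categories (6 company-scoped, 6 user-scoped)
# plus "none"; parse_categories() discards anything outside this set.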
VALID_CATEGORIES = {
    "company.brand_core",
    "company.strategic_signatures",
    "company.knowledge_artifacts",
    "company.business_priorities",
    "company.tools_config",
    "company.performance_context",
    "user.communication_style",
    "user.strategic_approach",
    "user.role_context",
    "user.workflow_patterns",
    "user.session_history",
    "user.interaction_preferences",
    "none",
}


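# Persistence tier per category ("long", "medium", "short", "rolling", "evolving"),
# used by compute_temporal_reward() to score temporal alignment.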
CATEGORY_PERSISTENCE = {
    "company.brand_core": "long",
    "company.strategic_signatures": "long",
    "company.knowledge_artifacts": "long",
    "company.business_priorities": "short",
    "company.tools_config": "medium",
    "company.performance_context": "rolling",
    "user.communication_style": "long",
    "user.strategic_approach": "long",
    "user.role_context": "medium",
    "user.workflow_patterns": "medium",
    "user.session_history": "short",
    "user.interaction_preferences": "evolving",
    "none": "short",
}


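# Company vs. user scope per category, used by compute_parity_reward().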
CATEGORY_SCOPE = {
    "company.brand_core": "company",
    "company.strategic_signatures": "company",
    "company.knowledge_artifacts": "company",
    "company.business_priorities": "company",
    "company.tools_config": "company",
    "company.performance_context": "company",
    "user.communication_style": "user",
    "user.strategic_approach": "user",
    "user.role_context": "user",
    "user.workflow_patterns": "user",
    "user.session_history": "user",
    "user.interaction_preferences": "user",
    "none": "none",
}


@dataclass
class RewardComponents:
    """Breakdown of reward computation."""
    r_f1: float = 0.0
    r_temp: float = 0.0
    r_parity: float = 0.0
    r_eff: float = 0.0
    r_total: float = 0.0
    format_valid: bool = True


def parse_categories(text: str) -> Tuple[Set[str], bool]:
    """
    Parse comma-separated categories from model output.

    Returns:
        (set of valid categories, parse_success)
    """
    if not text or not text.strip():
        return set(), False

    # Normalize: split on commas, strip whitespace, lowercase.
    raw_cats = [c.strip().lower() for c in text.split(",")]

    # Keep only labels from the known taxonomy.
    valid_cats = {c for c in raw_cats if c in VALID_CATEGORIES}

    # No recognizable category at all counts as a format failure.
    if not valid_cats:
        return set(), False

    # "none" is mutually exclusive with real categories; drop it if others are present.
    if "none" in valid_cats and len(valid_cats) > 1:
        valid_cats.discard("none")

    return valid_cats, True


def compute_f1(predicted: Set[str], gold: Set[str]) -> float:
    """
    Compute F1 score between predicted and gold category sets.

    Per PRD this is a multi-label task; here F1 is computed set-wise for a
    single example (precision and recall over the predicted/gold label sets).
    """
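    # Worked example: predicted={"company.brand_core"},
    # gold={"company.brand_core", "user.role_context"}
    #   -> precision = 1/1 = 1.0, recall = 1/2 = 0.5, F1 = 2*0.5/1.5 ~= 0.667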
    if not predicted and not gold:
        return 1.0
    if not predicted or not gold:
        return 0.0

    true_positives = len(predicted & gold)
    precision = true_positives / len(predicted) if predicted else 0.0
    recall = true_positives / len(gold) if gold else 0.0

    if precision + recall == 0:
        return 0.0

    return 2 * (precision * recall) / (precision + recall)


def compute_temporal_reward(predicted: Set[str], gold: Set[str]) -> float:
    """
    Compute temporal alignment reward.

    Per PRD:
    - +1.0 if predicted persistence matches gold
    - +0.5 if adjacent (long<->medium or medium<->short)
    - 0.0 otherwise
    - Use majority vote if multi-label
    """
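    # Worked example: predicted={"user.session_history"} ("short") vs.
    # gold={"user.role_context"} ("medium") -> adjacent tiers -> 0.5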
    if not predicted or not gold:
        return 0.0

    # Map each category to its persistence tier.
    pred_persistence = [CATEGORY_PERSISTENCE.get(c, "medium") for c in predicted]
    gold_persistence = [CATEGORY_PERSISTENCE.get(c, "medium") for c in gold]

    def majority(items):
        """Majority-vote persistence tier (ties resolved by Counter ordering)."""
        if not items:
            return "medium"
        return Counter(items).most_common(1)[0][0]

    pred_pers = majority(pred_persistence)
    gold_pers = majority(gold_persistence)

    if pred_pers == gold_pers:
        return 1.0

    # Tiers considered "adjacent" for partial credit.
    adjacent_pairs = {
        ("long", "medium"), ("medium", "long"),
        ("medium", "short"), ("short", "medium"),
        ("medium", "rolling"), ("rolling", "medium"),
        ("short", "rolling"), ("rolling", "short"),
    }

    if (pred_pers, gold_pers) in adjacent_pairs:
        return 0.5

    return 0.0


def compute_parity_reward(predicted: Set[str], gold: Set[str]) -> float:
    """
    Compute company/user scope alignment reward.

    Per PRD:
    - +1.0 if predicted scope matches gold scope exactly
    - 0.0 otherwise
    """
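    # Worked example: predicted={"company.tools_config"} and
    # gold={"company.brand_core"} are both company-scoped -> 1.0, even though
    # the categories themselves differ.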
    def get_scope(categories: Set[str]) -> str:
        scopes = {CATEGORY_SCOPE.get(c, "none") for c in categories}
        if "company" in scopes and "user" in scopes:
            return "mixed"
        elif "company" in scopes:
            return "company"
        elif "user" in scopes:
            return "user"
        else:
            return "none"

    pred_scope = get_scope(predicted)
    gold_scope = get_scope(gold)

    return 1.0 if pred_scope == gold_scope else 0.0


def compute_efficiency_reward(predicted: Set[str]) -> float:
    """
    Compute storage efficiency reward.

    Per PRD:
    - 1.0 if ≤3 categories
    - 0.7 if 4 categories
    - 0.4 if 5 categories
    - 0.0 if ≥6 categories
    """
    n = len(predicted)
    if n <= 3:
        return 1.0
    elif n == 4:
        return 0.7
    elif n == 5:
        return 0.4
    else:
        return 0.0


def compute_reward(predicted_text: str, gold_categories: List[str]) -> RewardComponents:
    """
    Compute the full reward for a prediction.

    Per PRD Section 4:
        R_total = 0.6 * R_F1 + 0.2 * R_temp + 0.1 * R_parity + 0.1 * R_eff

    A formatting failure (no parseable category) yields R_total = -1.0.

    Returns RewardComponents with the breakdown.
    """
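    # Worked example: predicted_text = "company.brand_core",
    # gold = ["company.brand_core", "company.strategic_signatures", "company.tools_config"]
    #   R_F1 = 0.5 (P=1.0, R=1/3), R_temp = 1.0 (majority tier "long" on both sides),
    #   R_parity = 1.0 (both company-scoped), R_eff = 1.0 (1 category)
    #   R_total = 0.6*0.5 + 0.2*1.0 + 0.1*1.0 + 0.1*1.0 = 0.7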
    result = RewardComponents()

    predicted, parse_success = parse_categories(predicted_text)
    gold = set(gold_categories)

    # Format penalty: unparseable output short-circuits with a negative reward.
    if not parse_success:
        result.format_valid = False
        result.r_total = -1.0
        return result

    result.r_f1 = compute_f1(predicted, gold)
    result.r_temp = compute_temporal_reward(predicted, gold)
    result.r_parity = compute_parity_reward(predicted, gold)
    result.r_eff = compute_efficiency_reward(predicted)

    # Weighted sum per PRD Section 4.
    result.r_total = (
        0.6 * result.r_f1 +
        0.2 * result.r_temp +
        0.1 * result.r_parity +
        0.1 * result.r_eff
    )

    return result


class MemoryRoutingEnv:
    """
    Single-step bandit environment for memory routing.

    Per Tinker Env interface:
    - initial_observation() -> (Observation, StopCondition)
    - step(action) -> StepResult

    Per PRD: Single-step episodes - step() terminates immediately with reward.
    """

    def __init__(
        self,
        conversation: List[Dict[str, str]],
        gold_categories: List[str],
        prompt_tokens: List[int],
        stop_tokens: List[int],
        scenario_id: str = "",
        tokenizer=None,
    ):
        self.conversation = conversation
        self.gold_categories = gold_categories
        self.prompt_tokens = prompt_tokens
        self.stop_tokens = stop_tokens
        self.scenario_id = scenario_id
        # Optional tokenizer used by step() to decode sampled action tokens back
        # to text; if None, step() falls back to str(action).
        self.tokenizer = tokenizer
        self._done = False

    async def initial_observation(self):
        """
        Return the initial observation (prompt tokens) and stop condition.

        Per Tinker: Returns (Observation, StopCondition)
        - Observation is the model input (tokens)
        - StopCondition tells the sampler when to stop
        """
        from tinker import types
        from tinker_cookbook.rl.types import StopCondition

        observation = types.ModelInput.from_ints(self.prompt_tokens)
        stop_condition = StopCondition(stop_tokens=self.stop_tokens)

        return observation, stop_condition

    async def step(self, action):
        """
        Process the model's action (generated tokens) and return reward.

        Per Tinker: Returns StepResult with reward and done=True
        Per PRD: Single-step bandit, so always terminates.
        """
        from tinker_cookbook.rl.types import StepResult

        # The env operates on tokens: the action is expected to be the sampled
        # token ids. Decode them with the tokenizer when one was provided
        # (assumes a .decode(list[int]) -> str interface); otherwise fall back
        # to str(action), which only works if the caller passed text directly.
        if isinstance(action, list) and self.tokenizer is not None:
            action_text = self.tokenizer.decode(action)
        else:
            action_text = str(action)

        reward_components = compute_reward(action_text, self.gold_categories)

        self._done = True

        return StepResult(
            reward=reward_components.r_total,
            done=True,
            info={
                "r_f1": reward_components.r_f1,
                "r_temp": reward_components.r_temp,
                "r_parity": reward_components.r_parity,
                "r_eff": reward_components.r_eff,
                "format_valid": reward_components.format_valid,
                "scenario_id": self.scenario_id,
            },
        )


class MemoryRoutingEnvGroupBuilder:
    """
    Builds a group of identical environments for variance reduction.

    Per Tinker docs (rl/rl-envs.mdx):
    - EnvGroupBuilder creates group_size copies of the same environment
    - This enables comparing multiple samples for the same input
    """

    def __init__(
        self,
        conversation: List[Dict[str, str]],
        gold_categories: List[str],
        prompt_tokens: List[int],
        stop_tokens: List[int],
        group_size: int = 8,
        scenario_id: str = "",
        tokenizer=None,
    ):
        self.conversation = conversation
        self.gold_categories = gold_categories
        self.prompt_tokens = prompt_tokens
        self.stop_tokens = stop_tokens
        self.group_size = group_size
        self.scenario_id = scenario_id
        # Passed through to each MemoryRoutingEnv so step() can decode tokens.
        self.tokenizer = tokenizer

    async def make_envs(self) -> Sequence["MemoryRoutingEnv"]:
        """Create group_size copies of the environment."""
        return [
            MemoryRoutingEnv(
                conversation=self.conversation,
                gold_categories=self.gold_categories,
                prompt_tokens=self.prompt_tokens,
                stop_tokens=self.stop_tokens,
                scenario_id=self.scenario_id,
                tokenizer=self.tokenizer,
            )
            for _ in range(self.group_size)
        ]

    def logging_tags(self) -> Dict[str, Any]:
        """Return tags for logging."""
        return {
            "scenario_id": self.scenario_id,
            "num_gold_categories": len(self.gold_categories),
            "has_none": "none" in self.gold_categories,
        }


class MemoryRoutingDataset:
    """
    Dataset of EnvGroupBuilders for RL training.

    Per Tinker docs (rl/rl-envs.mdx):
    - RLDataset.get_batch(index) returns list of EnvGroupBuilders
    """

    def __init__(
        self,
        examples: List[Dict[str, Any]],
        batch_size: int,
        group_size: int,
        renderer,
        tokenizer,
    ):
        self.examples = examples
        self.batch_size = batch_size
        self.group_size = group_size
        self.renderer = renderer
        self.tokenizer = tokenizer
        self.stop_tokens = renderer.get_stop_sequences()

    def __len__(self) -> int:
        return len(self.examples) // self.batch_size

    def get_batch(self, index: int) -> List[MemoryRoutingEnvGroupBuilder]:
        """Get a batch of EnvGroupBuilders."""
        start_idx = (index * self.batch_size) % len(self.examples)
        end_idx = start_idx + self.batch_size

        if end_idx <= len(self.examples):
            batch_examples = self.examples[start_idx:end_idx]
        else:
            # Wrap around to the start of the dataset to fill the batch.
            batch_examples = self.examples[start_idx:]
            batch_examples.extend(self.examples[:end_idx - len(self.examples)])

        builders = []
        for example in batch_examples:
            # Prefer pre-rendered messages; they are assumed to end with the
            # user turn (no assistant completion included).
            messages = example.get("messages", [])
            if not messages:
                # Otherwise build the routing prompt from the raw conversation
                # and drop the final (assistant) message before rendering.
                conversation = example.get("conversation", [])
                categories = example.get("labels", {}).get("categories", [])

                from training.preprocess import build_routing_prompt
                full_messages = build_routing_prompt(conversation, categories)
                messages = full_messages[:-1]

            prompt = self.renderer.build_generation_prompt(messages)
            prompt_tokens = prompt.to_ints()

            # Gold labels may live at the top level or under "labels".
            gold_categories = example.get("categories", [])
            if not gold_categories:
                gold_categories = example.get("labels", {}).get("categories", [])

            builders.append(MemoryRoutingEnvGroupBuilder(
                conversation=example.get("conversation", []),
                gold_categories=gold_categories,
                prompt_tokens=prompt_tokens,
                stop_tokens=self.stop_tokens,
                group_size=self.group_size,
                scenario_id=example.get("scenario_id", ""),
                tokenizer=self.tokenizer,
            ))

        return builders


class MemoryRoutingDatasetBuilder:
    """
    Factory for creating train/test RL datasets.

    Per Tinker pattern from math_env.py example.
    """

    def __init__(
        self,
        train_data_path: str,
        test_data_path: str,
        batch_size: int = 64,
        group_size: int = 8,
        model_name: str = "meta-llama/Llama-3.1-8B",
        renderer_name: str = "llama3",
    ):
        self.train_data_path = train_data_path
        self.test_data_path = test_data_path
        self.batch_size = batch_size
        self.group_size = group_size
        self.model_name = model_name
        self.renderer_name = renderer_name

    def __call__(self) -> Tuple[MemoryRoutingDataset, MemoryRoutingDataset]:
        """Create train and test datasets."""
        from tinker_cookbook import renderers, tokenizer_utils

        tokenizer = tokenizer_utils.get_tokenizer(self.model_name)
        renderer = renderers.get_renderer(name=self.renderer_name, tokenizer=tokenizer)

        with open(self.train_data_path, "r") as f:
            train_examples = json.load(f)

        with open(self.test_data_path, "r") as f:
            test_examples = json.load(f)

        train_dataset = MemoryRoutingDataset(
            examples=train_examples,
            batch_size=self.batch_size,
            group_size=self.group_size,
            renderer=renderer,
            tokenizer=tokenizer,
        )

        test_dataset = MemoryRoutingDataset(
            examples=test_examples,
            batch_size=min(self.batch_size, len(test_examples)),
            group_size=self.group_size,
            renderer=renderer,
            tokenizer=tokenizer,
        )

        return train_dataset, test_dataset


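# Minimal usage sketch (paths are illustrative, not part of this module):
#
#   builder = MemoryRoutingDatasetBuilder(
#       train_data_path="data/rl_train.json",
#       test_data_path="data/rl_test.json",
#       batch_size=64,
#       group_size=8,
#   )
#   train_ds, test_ds = builder()
#   first_batch = train_ds.get_batch(0)  # list of MemoryRoutingEnvGroupBuilder

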
if __name__ == "__main__":
    # Smoke test: (predicted_text, gold_categories, expected_format_valid).
    test_cases = [
        ("company.brand_core, user.strategic_approach", ["company.brand_core", "user.strategic_approach"], True),
        ("none", ["none"], True),
        ("company.brand_core, none", ["company.brand_core"], True),
        ("invalid_category", ["company.brand_core"], False),
        ("", ["company.brand_core"], False),
        ("company.brand_core", ["company.brand_core", "user.role_context"], True),
    ]

    print("Testing reward computation:")
    print("=" * 60)

    for pred, gold, expected_valid in test_cases:
        result = compute_reward(pred, gold)
        print(f"\nPredicted: '{pred}'")
        print(f"Gold: {gold}")
        print(f"Format valid: {result.format_valid} (expected: {expected_valid})")
        print(f"R_F1: {result.r_f1:.3f}")
        print(f"R_temp: {result.r_temp:.3f}")
        print(f"R_parity: {result.r_parity:.3f}")
        print(f"R_eff: {result.r_eff:.3f}")
        print(f"R_total: {result.r_total:.3f}")