"""Unit tests for the ml_trainer_env inference module."""
from types import SimpleNamespace

import pytest

import inference


def make_tool_message(name: str, arguments: str):
    """Build a chat-completion message stub containing exactly one tool call."""
return SimpleNamespace(
tool_calls=[
SimpleNamespace(
function=SimpleNamespace(name=name, arguments=arguments)
)
]
    )


def test_request_scheduler_enforces_min_gap():
    # Fake clock: time advances only when fake_sleep runs, so `sleeps`
    # records exactly the delays the scheduler requested.
    current_time = {"value": 100.0}
sleeps = []
def fake_time():
return current_time["value"]
def fake_sleep(seconds: float):
sleeps.append(seconds)
current_time["value"] += seconds
scheduler = inference.RequestScheduler(
min_gap_seconds=12.5,
rpm_limit=5,
time_fn=fake_time,
sleep_fn=fake_sleep,
)
scheduler.wait_for_turn()
current_time["value"] += 2.0
scheduler.wait_for_turn()
assert sleeps == [pytest.approx(10.5)]
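

# A minimal companion sketch, assuming wait_for_turn only sleeps when the gap
# since the previous request is shorter than min_gap_seconds (consistent with
# the first call above not sleeping). The test name and constants here are
# hypothetical additions, not part of the original suite.
def test_request_scheduler_skips_sleep_after_long_gap():
    current_time = {"value": 100.0}
    sleeps = []
    scheduler = inference.RequestScheduler(
        min_gap_seconds=12.5,
        rpm_limit=5,
        time_fn=lambda: current_time["value"],
        sleep_fn=lambda seconds: sleeps.append(seconds),
    )
    scheduler.wait_for_turn()
    current_time["value"] += 20.0  # well past the 12.5 s minimum gap
    scheduler.wait_for_turn()
    assert sleeps == []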


def test_parse_tool_call_accepts_valid_tool_call():
message = make_tool_message("run_epochs", '{"num_epochs": 8}')
tool_call = inference.parse_tool_call(message)
assert tool_call == {"tool_name": "run_epochs", "arguments": {"num_epochs": 8}}


def test_parse_tool_call_rejects_text_only_response():
message = SimpleNamespace(tool_calls=[])
with pytest.raises(inference.InferenceError, match="exactly one tool call"):
inference.parse_tool_call(message)


def test_parse_tool_call_rejects_malformed_arguments():
message = make_tool_message("run_epochs", "{bad json")
with pytest.raises(inference.InferenceError, match="valid JSON"):
inference.parse_tool_call(message)
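

# Hedged extension of the rejection tests above, assuming the "exactly one
# tool call" error also covers messages carrying more than one call (not just
# zero). This test is a sketch, not part of the original suite; drop it if
# parse_tool_call is meant to take the first call and ignore the rest.
def test_parse_tool_call_rejects_multiple_tool_calls():
    message = SimpleNamespace(
        tool_calls=[
            SimpleNamespace(function=SimpleNamespace(name="run_epochs", arguments="{}")),
            SimpleNamespace(function=SimpleNamespace(name="submit_model", arguments="{}")),
        ]
    )
    with pytest.raises(inference.InferenceError, match="exactly one tool call"):
        inference.parse_tool_call(message)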


def test_request_action_retries_rate_limit_with_retry_after(monkeypatch):
sleeps = []
monkeypatch.setattr(inference.time, "sleep", lambda seconds: sleeps.append(seconds))
    # Stand-in for the SDK's RateLimitError: it exposes a response whose
    # retry-after header request_action should honor before retrying.
    class DummyRateLimitError(Exception):
        def __init__(self, retry_after: str):
            self.response = SimpleNamespace(headers={"retry-after": retry_after})
monkeypatch.setattr(inference, "RateLimitError", DummyRateLimitError)
responses = [
DummyRateLimitError("7"),
SimpleNamespace(
choices=[SimpleNamespace(message=make_tool_message("submit_model", "{}"))]
),
]
class FakeCompletions:
def create(self, **kwargs):
response = responses.pop(0)
if isinstance(response, Exception):
raise response
return response
client = SimpleNamespace(chat=SimpleNamespace(completions=FakeCompletions()))
scheduler = inference.RequestScheduler(
min_gap_seconds=0.0,
rpm_limit=5,
time_fn=lambda: 0.0,
sleep_fn=lambda seconds: None,
)
stats = inference.TaskStats()
tool_call = inference.request_action(
client=client,
scheduler=scheduler,
messages=[{"role": "user", "content": "test"}],
stats=stats,
decision_index=1,
max_decisions=3,
)
assert tool_call["tool_name"] == "submit_model"
assert stats.requests == 2
assert stats.retries == 1
assert sleeps == [7.0]


def test_extract_reset_metadata_reads_wrapped_result_payload():
payload = {
"observation": {
"metadata": {},
"result": {
"data": {
"task_name": "MNIST Digit Classifier",
"difficulty": "easy",
"dataset": "mnist",
"max_epochs": 100,
}
},
}
}
metadata = inference.extract_reset_metadata(payload)
assert metadata["task_name"] == "MNIST Digit Classifier"
assert metadata["difficulty"] == "easy"


def test_extract_result_data_reads_json_content_text():
payload = {
"observation": {
"result": {
"content": [
{
"type": "text",
"text": '{"status":"configured","metrics":{"current_epoch":0,"remaining_budget":100,"train_loss":0.0,"val_loss":0.0,"train_accuracy":0.0,"val_accuracy":0.0,"best_val_accuracy":0.0,"convergence_signal":"not_started","is_diverged":false}}',
}
]
}
}
}
result = inference.extract_result_data(payload)
normalized = inference.normalize_tool_result(result)
assert normalized["current_epoch"] == 0
assert normalized["remaining_budget"] == 100
assert normalized["status"] == "configured"


def test_run_task_forces_configure_then_submit(monkeypatch):
    # run_task should pin tool_choice to configure_training on the first
    # decision and to submit_model on the last one.
    monkeypatch.setitem(inference.LLM_MAX_STEPS, "easy_mnist", 2)
tool_choices = []
responses = [
SimpleNamespace(
choices=[
SimpleNamespace(
message=make_tool_message(
"configure_training",
'{"optimizer":"adam","learning_rate":0.001,"batch_size":64,"weight_decay":0.0,"dropout":0.0,"lr_schedule":"cosine","warmup_epochs":3,"augmentation":false,"augmentation_strength":0.0}',
)
)
]
),
SimpleNamespace(
choices=[SimpleNamespace(message=make_tool_message("submit_model", "{}"))]
),
]
class FakeCompletions:
def create(self, **kwargs):
tool_choices.append(kwargs["tool_choice"])
return responses.pop(0)
client = SimpleNamespace(chat=SimpleNamespace(completions=FakeCompletions()))
scheduler = inference.RequestScheduler(
min_gap_seconds=0.0,
rpm_limit=5,
time_fn=lambda: 0.0,
sleep_fn=lambda seconds: None,
)
class FakeEnv:
def sync(self):
return self
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return None
def reset(self, **kwargs):
return SimpleNamespace(
observation=SimpleNamespace(
metadata={
"task_id": kwargs["task_id"],
"task_name": "MNIST Digit Classifier",
"task_description": "desc",
"difficulty": "easy",
"model_type": "simple_mlp",
"dataset": "mnist",
"max_epochs": 100,
"target_metric": "val_accuracy",
"target_value": 0.96,
}
)
)
def step(self, action):
if action.tool_name == "configure_training":
return SimpleNamespace(
observation=SimpleNamespace(
metadata={},
result={
"data": {
"status": "configured",
"metrics": {
"current_epoch": 0,
"remaining_budget": 100,
"train_loss": 0.0,
"val_loss": 0.0,
"train_accuracy": 0.0,
"val_accuracy": 0.0,
"best_val_accuracy": 0.0,
"convergence_signal": "not_started",
"is_diverged": False,
"current_config": {
"optimizer": "adam",
"learning_rate": 0.001,
},
},
}
},
),
reward=0.0,
done=False,
)
return SimpleNamespace(
observation=SimpleNamespace(
metadata={},
result={"data": {"grade": {"score": 0.8}}},
),
reward=0.8,
done=True,
)
monkeypatch.setattr(inference, "MLTrainerEnv", lambda base_url: FakeEnv())
result = inference.run_task(client, scheduler, "easy_mnist")
assert tool_choices[0] == {"type": "function", "function": {"name": "configure_training"}}
assert tool_choices[-1] == {"type": "function", "function": {"name": "submit_model"}}
assert result["llm_decisions"] == 2
assert result["final_score"] == 0.8


def test_run_task_reuses_same_env_session(monkeypatch):
calls = []
class FakeCompletions:
def create(self, **kwargs):
return SimpleNamespace(
choices=[SimpleNamespace(message=make_tool_message("submit_model", "{}"))]
)
client = SimpleNamespace(chat=SimpleNamespace(completions=FakeCompletions()))
scheduler = inference.RequestScheduler(
min_gap_seconds=0.0,
rpm_limit=5,
time_fn=lambda: 0.0,
sleep_fn=lambda seconds: None,
)
monkeypatch.setitem(inference.LLM_MAX_STEPS, "easy_mnist", 1)
class FakeEnv:
def sync(self):
return self
def __enter__(self):
calls.append("enter")
return self
def __exit__(self, exc_type, exc, tb):
calls.append("exit")
return None
def reset(self, **kwargs):
calls.append(("reset", kwargs["task_id"]))
return SimpleNamespace(
observation=SimpleNamespace(
metadata={
"task_id": kwargs["task_id"],
"task_name": "MNIST Digit Classifier",
"difficulty": "easy",
"dataset": "mnist",
"max_epochs": 100,
"target_metric": "val_accuracy",
"target_value": 0.96,
}
)
)
def step(self, action):
calls.append(("step", action.tool_name))
return SimpleNamespace(
observation=SimpleNamespace(
metadata={},
result={"data": {"grade": {"score": 0.7}}},
),
reward=0.7,
done=True,
)
monkeypatch.setattr(inference, "MLTrainerEnv", lambda base_url: FakeEnv())
result = inference.run_task(client, scheduler, "easy_mnist")
assert result["final_score"] == 0.7
assert calls == ["enter", ("reset", "easy_mnist"), ("step", "submit_model"), "exit"]


def test_request_action_does_not_send_seed(monkeypatch):
captured_kwargs = {}
class FakeCompletions:
def create(self, **kwargs):
captured_kwargs.update(kwargs)
return SimpleNamespace(
choices=[SimpleNamespace(message=make_tool_message("submit_model", "{}"))]
)
client = SimpleNamespace(chat=SimpleNamespace(completions=FakeCompletions()))
scheduler = inference.RequestScheduler(
min_gap_seconds=0.0,
rpm_limit=5,
time_fn=lambda: 0.0,
sleep_fn=lambda seconds: None,
)
stats = inference.TaskStats()
inference.request_action(
client=client,
scheduler=scheduler,
messages=[{"role": "user", "content": "test"}],
stats=stats,
decision_index=1,
max_decisions=3,
)
assert "seed" not in captured_kwargs


def test_apply_tool_context_preserves_current_config():
configured = inference.apply_tool_context(
"configure_training",
{"optimizer": "adam", "learning_rate": 0.001, "batch_size": 64},
{},
{"current_epoch": 0, "remaining_budget": 100},
)
assert configured["current_config"]["optimizer"] == "adam"
updated = inference.apply_tool_context(
"run_epochs",
{"num_epochs": 10},
configured,
{"current_epoch": 10, "best_val_accuracy": 0.95},
)
assert updated["current_config"]["batch_size"] == 64
assert updated["best_val_accuracy"] == 0.95


def test_merge_task_metadata_fills_missing_fields():
merged = inference.merge_task_metadata("easy_mnist", {"task_id": "easy_mnist"})
assert merged["difficulty"] == "easy"
assert merged["dataset"] == "mnist"
assert merged["max_epochs"] == 100