"""Run orchestration engines and reproducible execution helpers."""
from __future__ import annotations
import importlib
import random
import time
from collections.abc import Callable, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .adapters.agents import execute_agent
from .adapters.problems import evaluate_problem, resolve_problem
from .artifacts import (
checkpoint_run_result,
export_canonical_artifacts,
load_checkpointed_run_results,
)
from .conditions import Condition
from .designs import build_design
from .metrics import compose_metrics
from .schemas import RunStatus, ValidationError, hash_identifier, stable_json_dumps, utc_now_iso
from .study import RunResult, RunSpec, Study, validate_study
@dataclass(slots=True)
class DryRunReport:
    """Dry-run validation report for a planned execution."""

    # Human-readable validation errors; an empty list means the plan is runnable.
    errors: list[str]
    # Number of run specs that would be executed for this plan.
    planned_runs: int
class SerialRunner:
    """Serial orchestration runner.

    Executes runs sequentially in the given order, optionally checkpointing
    each result and stopping early on the first failure when ``fail_fast``.
    """

    def run(
        self,
        *,
        run_specs: Sequence[RunSpec],
        condition_by_id: Mapping[str, Condition],
        agent_factories: Mapping[str, Callable[[Condition], Any]] | None,
        problem_registry: Mapping[str, Any] | None,
        output_dir: Path,
        checkpoint: bool,
        fail_fast: bool,
    ) -> list[RunResult]:
        """Execute all run specs one-by-one and return their results."""
        completed: list[RunResult] = []
        for spec in run_specs:
            outcome = _execute_single_run(
                run_spec=spec,
                condition=condition_by_id[spec.condition_id],
                agent_factories=agent_factories,
                problem_registry=problem_registry,
            )
            completed.append(outcome)
            # Persist incrementally so an interrupted study can be resumed.
            if checkpoint:
                checkpoint_run_result(outcome, output_dir=output_dir)
            if fail_fast and outcome.status == RunStatus.FAILED:
                break
        return completed
class LocalParallelRunner:
    """Thread-based local parallel orchestration runner.

    Fans runs out to a thread pool and collects results in completion order.
    On a failure with ``fail_fast``, not-yet-started futures are cancelled.
    """

    def run(
        self,
        *,
        run_specs: Sequence[RunSpec],
        condition_by_id: Mapping[str, Condition],
        agent_factories: Mapping[str, Callable[[Condition], Any]] | None,
        problem_registry: Mapping[str, Any] | None,
        output_dir: Path,
        checkpoint: bool,
        fail_fast: bool,
        max_workers: int,
    ) -> list[RunResult]:
        """Execute run specs in a local thread pool."""
        collected: list[RunResult] = []
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            # Submission order follows run_specs; results arrive as they finish.
            submitted: dict[Future[RunResult], str] = {
                pool.submit(
                    _execute_single_run,
                    run_spec=spec,
                    condition=condition_by_id[spec.condition_id],
                    agent_factories=agent_factories,
                    problem_registry=problem_registry,
                ): spec.run_id
                for spec in run_specs
            }
            for finished in as_completed(submitted):
                outcome = finished.result()
                collected.append(outcome)
                if checkpoint:
                    checkpoint_run_result(outcome, output_dir=output_dir)
                if fail_fast and outcome.status == RunStatus.FAILED:
                    # Best-effort cancellation: only futures not yet running
                    # can actually be cancelled.
                    for other in submitted:
                        if not other.done():
                            other.cancel()
                    break
        return collected
def run_study(
    study: Study,
    *,
    conditions: Sequence[Condition] | None = None,
    agent_factories: Mapping[str, Callable[[Condition], Any]] | None = None,
    problem_registry: Mapping[str, Any] | None = None,
    parallelism: int | None = None,
    dry_run: bool = False,
    resume: bool = False,
    checkpoint: bool = True,
    include_sqlite: bool = False,
) -> list[RunResult]:
    """Run a study end-to-end and export canonical artifacts.

    Args:
        study: Study definition to execute.
        conditions: Explicit conditions to run; defaults to ``build_design(study)``.
        agent_factories: Optional agent factories keyed by agent spec reference.
        problem_registry: Optional registry used to resolve problem spec refs.
        parallelism: Worker count; defaults to ``study.run_budget.parallelism``.
        dry_run: Validate only and return an empty result list.
        resume: Skip runs already checkpointed under the output directory.
        checkpoint: Persist each run result as it completes.
        include_sqlite: Also export a SQLite artifact.

    Returns:
        All run results (previously checkpointed plus newly executed).

    Raises:
        ValidationError: If dry-run validation reports errors or the resolved
            parallelism is below 1.
    """
    resolved_conditions = list(conditions) if conditions is not None else build_design(study)
    report = dry_run_validate(study, conditions=resolved_conditions)
    if report.errors:
        raise ValidationError("\n".join(report.errors))
    if dry_run:
        return []
    output_dir = Path(study.output_dir or Path("artifacts") / study.study_id)
    output_dir.mkdir(parents=True, exist_ok=True)
    all_run_specs = _build_run_specs(study=study, conditions=resolved_conditions)
    condition_by_id = {condition.condition_id: condition for condition in resolved_conditions}
    existing_results: list[RunResult] = []
    completed_run_ids: set[str] = set()
    if resume:
        # Reload previously completed runs so they are neither re-executed
        # nor missing from the exported artifacts.
        checkpointed = load_checkpointed_run_results(output_dir)
        existing_results = list(checkpointed.values())
        completed_run_ids = set(checkpointed)
    pending_run_specs = [
        run_spec for run_spec in all_run_specs if run_spec.run_id not in completed_run_ids
    ]
    resolved_parallelism = parallelism if parallelism is not None else study.run_budget.parallelism
    if resolved_parallelism < 1:
        raise ValidationError("parallelism must be >= 1.")
    if resolved_parallelism == 1:
        serial_runner = SerialRunner()
        new_results = serial_runner.run(
            run_specs=pending_run_specs,
            condition_by_id=condition_by_id,
            agent_factories=agent_factories,
            problem_registry=problem_registry,
            output_dir=output_dir,
            checkpoint=checkpoint,
            fail_fast=study.run_budget.fail_fast,
        )
    else:
        parallel_runner = LocalParallelRunner()
        new_results = parallel_runner.run(
            run_specs=pending_run_specs,
            condition_by_id=condition_by_id,
            agent_factories=agent_factories,
            problem_registry=problem_registry,
            output_dir=output_dir,
            checkpoint=checkpoint,
            fail_fast=study.run_budget.fail_fast,
            max_workers=resolved_parallelism,
        )
    all_results = existing_results + new_results
    export_canonical_artifacts(
        study=study,
        conditions=resolved_conditions,
        run_results=all_results,
        output_dir=output_dir,
        include_sqlite=include_sqlite,
    )
    return all_results
def resume_study(
    study: Study,
    *,
    conditions: Sequence[Condition] | None = None,
    agent_factories: Mapping[str, Callable[[Condition], Any]] | None = None,
    problem_registry: Mapping[str, Any] | None = None,
    parallelism: int | None = None,
    checkpoint: bool = True,
    include_sqlite: bool = False,
) -> list[RunResult]:
    """Resume a study from checkpointed run results.

    Thin wrapper around :func:`run_study` with ``resume=True`` and
    ``dry_run=False``; all other arguments are forwarded unchanged.
    """
    return run_study(
        study,
        conditions=conditions,
        agent_factories=agent_factories,
        problem_registry=problem_registry,
        parallelism=parallelism,
        dry_run=False,
        resume=True,
        checkpoint=checkpoint,
        include_sqlite=include_sqlite,
    )
def dry_run_validate(study: Study, *, conditions: Sequence[Condition]) -> DryRunReport:
    """Validate run inputs before launching execution.

    Collects study-level validation errors, checks that at least one
    admissible condition and one run spec exist, and verifies the plan fits
    within the study's run budget.
    """
    problems: list[str] = list(validate_study(study))
    if not any(condition.admissible for condition in conditions):
        problems.append("No admissible conditions are available to execute.")
    planned = len(_build_run_specs(study=study, conditions=conditions))
    if planned < 1:
        problems.append("No run specifications were generated.")
    max_runs = study.run_budget.max_runs
    if max_runs is not None and planned > max_runs:
        problems.append(
            f"Planned runs ({planned}) exceed run budget max_runs ({max_runs})."
        )
    return DryRunReport(errors=problems, planned_runs=planned)
def reproducible_seed(seed: int) -> Any:
    """Return a context manager that seeds RNGs deterministically.

    On entry the stdlib (and, when installed, numpy) RNGs are seeded with
    ``seed``; on exit their previous states are restored.
    """
    return _reproducible_seed(seed)
@contextmanager
def _reproducible_seed(seed: int) -> Any:
    """Temporarily set deterministic RNG seeds for one run.

    Saves and restores both the stdlib ``random`` state and, when numpy is
    importable, the numpy global RNG state. Numpy is strictly optional: any
    failure while touching it is swallowed and numpy is left alone.
    """
    saved_python_state = random.getstate()
    random.seed(seed)
    numpy_restore: tuple[Any, Any] | None = None
    try:
        np_random = importlib.import_module("numpy").random
        # Capture (saved state, setter) before reseeding so both survive
        # regardless of what happens inside the context.
        numpy_restore = (np_random.get_state(), np_random.set_state)
        np_random.seed(seed)
    except Exception:
        # numpy missing or its RNG API unusable -- proceed stdlib-only.
        numpy_restore = None
    try:
        yield
    finally:
        random.setstate(saved_python_state)
        if numpy_restore is not None:
            saved_state, setter = numpy_restore
            setter(saved_state)
def _build_run_specs(study: Study, conditions: Sequence[Condition]) -> list[RunSpec]:
    """Materialize deterministic run specs for all admissible conditions.

    One spec is produced per (condition, replicate, agent, problem) tuple,
    with a content-hashed run ID and a policy-derived seed. The result is
    truncated to ``run_budget.max_runs`` when that budget is set.
    """
    specs: list[RunSpec] = []
    for condition in (c for c in conditions if c.admissible):
        agents = _resolve_agent_ids(study=study, condition=condition)
        problems = _resolve_problem_ids(study=study, condition=condition)
        for replicate in range(1, study.run_budget.replicates + 1):
            for agent_id in agents:
                for problem_id in problems:
                    # Identity payload that makes the run ID deterministic.
                    identity = {
                        "study_id": study.study_id,
                        "condition_id": condition.condition_id,
                        "replicate": replicate,
                        "agent_id": agent_id,
                        "problem_id": problem_id,
                    }
                    specs.append(
                        RunSpec(
                            run_id=hash_identifier("run", identity),
                            study_id=study.study_id,
                            condition_id=condition.condition_id,
                            problem_id=problem_id,
                            replicate=replicate,
                            seed=study.seed_policy.derive_seed(
                                study_id=study.study_id,
                                condition_id=condition.condition_id,
                                replicate=replicate,
                                salt=f"{agent_id}:{problem_id}",
                            ),
                            agent_spec_ref=agent_id,
                            problem_spec_ref=problem_id,
                            execution_metadata={
                                "agent_id": agent_id,
                                "problem_id": problem_id,
                                "condition_fingerprint": condition.metadata.get("fingerprint"),
                            },
                        )
                    )
    max_runs = study.run_budget.max_runs
    return specs if max_runs is None else specs[:max_runs]
def _resolve_agent_ids(study: Study, condition: Condition) -> tuple[str, ...]:
    """Resolve agent IDs for one condition.

    A condition-level factor assignment wins over the study's agent specs;
    when neither is present a placeholder agent ID is returned.
    """
    assignments = condition.factor_assignments
    override = next(
        (key for key in ("agent_id", "agent", "agent_spec") if key in assignments),
        None,
    )
    if override is not None:
        return (str(assignments[override]),)
    return tuple(study.agent_specs) if study.agent_specs else ("default-agent",)
def _resolve_problem_ids(study: Study, condition: Condition) -> tuple[str, ...]:
    """Resolve problem IDs for one condition.

    A condition-level factor assignment wins over the study's problem list;
    when neither is present a placeholder problem ID is returned.
    """
    assignments = condition.factor_assignments
    override = next(
        (key for key in ("problem_id", "problem") if key in assignments),
        None,
    )
    if override is not None:
        return (str(assignments[override]),)
    return tuple(study.problem_ids) if study.problem_ids else ("default-problem",)
def _execute_single_run(
    *,
    run_spec: RunSpec,
    condition: Condition,
    agent_factories: Mapping[str, Callable[[Condition], Any]] | None,
    problem_registry: Mapping[str, Any] | None,
) -> RunResult:
    """Execute one run spec with failure isolation.

    Resolves the problem, runs the agent under a deterministic seed,
    evaluates the output, and composes metrics. Any exception is captured
    and converted into a FAILED ``RunResult`` instead of propagating, so a
    single bad run cannot abort sibling runs in a batch.

    Args:
        run_spec: The run to execute. Its ``execution_metadata`` is mutated
            in place with the resolved problem family.
        condition: The condition assigned to this run.
        agent_factories: Optional agent factories keyed by agent spec ref.
        problem_registry: Optional registry used to resolve the problem ref.

    Returns:
        A SUCCESS result with metrics and provenance, or a FAILED result
        whose ``error_info`` holds the exception type and message.
    """
    started_at = utc_now_iso()
    start_time = time.perf_counter()
    try:
        # Seed RNGs for this run only; prior state is restored on exit.
        with reproducible_seed(run_spec.seed):
            problem_packet = resolve_problem(
                run_spec.problem_spec_ref,
                registry=problem_registry,
            )
            run_spec.execution_metadata["problem_family"] = problem_packet.family
            agent_execution = execute_agent(
                agent_spec_ref=run_spec.agent_spec_ref,
                run_spec=run_spec,
                condition=condition,
                problem_packet=problem_packet,
                factories=agent_factories,
            )
            evaluation_rows = evaluate_problem(problem_packet, agent_execution.output)
            # Tag each evaluator row with the run it belongs to.
            for row in evaluation_rows:
                row["run_id"] = run_spec.run_id
        latency_s = time.perf_counter() - start_time
        # Missing or None cost reported by the agent is treated as 0.0.
        cost_usd = float(agent_execution.metrics.get("cost_usd", 0.0) or 0.0)
        metrics = compose_metrics(
            agent_metrics=agent_execution.metrics,
            evaluation_rows=evaluation_rows,
            observations=agent_execution.events,
            latency_s=latency_s,
            cost_usd=cost_usd,
        )
        provenance_info = {
            "agent_id": run_spec.execution_metadata.get("agent_id"),
            "problem_id": run_spec.problem_id,
            "problem_family": problem_packet.family,
            "model_name": agent_execution.metadata.get("model_name"),
            "execution_metadata": stable_json_dumps(run_spec.execution_metadata),
        }
        return RunResult(
            run_id=run_spec.run_id,
            status=RunStatus.SUCCESS,
            outputs=agent_execution.output,
            metrics=metrics,
            evaluator_outputs=evaluation_rows,
            cost=cost_usd,
            latency=latency_s,
            trace_refs=list(agent_execution.trace_refs),
            artifact_refs=[],
            error_info=None,
            provenance_info=provenance_info,
            observations=list(agent_execution.events),
            run_spec=run_spec,
            started_at=started_at,
            ended_at=utc_now_iso(),
        )
    except Exception as exc:
        # Failure isolation boundary: record the error and elapsed time
        # rather than letting the exception escape to the runner.
        latency_s = time.perf_counter() - start_time
        return RunResult(
            run_id=run_spec.run_id,
            status=RunStatus.FAILED,
            outputs={},
            metrics={"latency_s": latency_s},
            evaluator_outputs=[],
            cost=0.0,
            latency=latency_s,
            trace_refs=[],
            artifact_refs=[],
            error_info=f"{type(exc).__name__}: {exc}",
            provenance_info={"agent_id": run_spec.execution_metadata.get("agent_id")},
            observations=[],
            run_spec=run_spec,
            started_at=started_at,
            ended_at=utc_now_iso(),
        )