Source code for design_research_experiments.runners

"""Run orchestration engines and reproducible execution helpers."""

from __future__ import annotations

import importlib
import random
import time
from collections.abc import Callable, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from .adapters.agents import execute_agent
from .adapters.problems import evaluate_problem, resolve_problem
from .artifacts import (
    checkpoint_run_result,
    export_canonical_artifacts,
    load_checkpointed_run_results,
)
from .conditions import Condition
from .designs import build_design
from .metrics import compose_metrics
from .schemas import RunStatus, ValidationError, hash_identifier, stable_json_dumps, utc_now_iso
from .study import RunResult, RunSpec, Study, validate_study


@dataclass(slots=True)
class DryRunReport:
    """Dry-run validation report for a planned execution.

    Attributes:
        errors: Human-readable validation messages collected for the plan.
            An empty list means no problem was detected.
        planned_runs: Number of run specifications generated for the plan.
    """

    errors: list[str]
    planned_runs: int


class SerialRunner:
    """Runner that executes run specs sequentially in the calling thread."""

    def run(
        self,
        *,
        run_specs: Sequence[RunSpec],
        condition_by_id: Mapping[str, Condition],
        agent_factories: Mapping[str, Callable[[Condition], Any]] | None,
        problem_registry: Mapping[str, Any] | None,
        output_dir: Path,
        checkpoint: bool,
        fail_fast: bool,
    ) -> list[RunResult]:
        """Execute every run spec in order and collect the results.

        Args:
            run_specs: Run specifications to execute, in execution order.
            condition_by_id: Lookup used to find each spec's condition by its
                ``condition_id``.
            agent_factories: Optional agent factories forwarded to each run.
            problem_registry: Optional problem registry forwarded to each run.
            output_dir: Directory handed to the checkpoint writer.
            checkpoint: When true, each result is checkpointed right after it
                is produced.
            fail_fast: When true, stop after the first result whose status is
                ``RunStatus.FAILED``.

        Returns:
            Results in execution order; the failing result is included when
            ``fail_fast`` stops the loop early.

        Raises:
            KeyError: If a spec refers to a condition id that is missing from
                ``condition_by_id``.
        """
        collected: list[RunResult] = []
        for spec in run_specs:
            outcome = _execute_single_run(
                run_spec=spec,
                condition=condition_by_id[spec.condition_id],
                agent_factories=agent_factories,
                problem_registry=problem_registry,
            )
            collected.append(outcome)

            # Persist before deciding whether to stop, so the failing run is
            # checkpointed too.
            if checkpoint:
                checkpoint_run_result(outcome, output_dir=output_dir)

            should_stop = fail_fast and outcome.status == RunStatus.FAILED
            if should_stop:
                break
        return collected


class LocalParallelRunner:
    """Thread-based local parallel orchestration runner."""

    def run(
        self,
        *,
        run_specs: Sequence[RunSpec],
        condition_by_id: Mapping[str, Condition],
        agent_factories: Mapping[str, Callable[[Condition], Any]] | None,
        problem_registry: Mapping[str, Any] | None,
        output_dir: Path,
        checkpoint: bool,
        fail_fast: bool,
        max_workers: int,
    ) -> list[RunResult]:
        """Execute run specs in a local thread pool.

        Args:
            run_specs: Run specifications to execute.
            condition_by_id: Lookup used to find each spec's condition by its
                ``condition_id``.
            agent_factories: Optional agent factories forwarded to each run.
            problem_registry: Optional problem registry forwarded to each run.
            output_dir: Directory handed to the checkpoint writer.
            checkpoint: When true, each result is checkpointed as soon as its
                future completes.
            fail_fast: When true, the first ``RunStatus.FAILED`` result cancels
                every run that has not started yet.
            max_workers: Size of the thread pool.

        Returns:
            Results in completion order (not submission order). After a
            fail-fast stop this still includes runs that were already
            executing when the failure was observed.

        Raises:
            KeyError: If a spec refers to a condition id that is missing from
                ``condition_by_id``.
        """
        results: list[RunResult] = []
        stop_requested = False

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures: list[Future[RunResult]] = [
                executor.submit(
                    _execute_single_run,
                    run_spec=run_spec,
                    condition=condition_by_id[run_spec.condition_id],
                    agent_factories=agent_factories,
                    problem_registry=problem_registry,
                )
                for run_spec in run_specs
            ]

            for future in as_completed(futures):
                # Futures cancelled by fail-fast never ran and have no result.
                if future.cancelled():
                    continue

                result = future.result()
                results.append(result)
                if checkpoint:
                    checkpoint_run_result(result, output_dir=output_dir)

                if fail_fast and not stop_requested and result.status == RunStatus.FAILED:
                    stop_requested = True
                    # cancel() is a no-op for running/finished futures, so only
                    # queued runs are dropped. Keep draining instead of
                    # breaking: the executor waits for in-flight runs on exit
                    # anyway, and their results must not be silently lost.
                    for pending_future in futures:
                        pending_future.cancel()

        return results


def run_study(
    study: Study,
    *,
    conditions: Sequence[Condition] | None = None,
    agent_factories: Mapping[str, Callable[[Condition], Any]] | None = None,
    problem_registry: Mapping[str, Any] | None = None,
    parallelism: int | None = None,
    dry_run: bool = False,
    resume: bool = False,
    checkpoint: bool = True,
    include_sqlite: bool = False,
) -> list[RunResult]:
    """Validate, execute, and export one study.

    Args:
        study: Study definition to run.
        conditions: Explicit conditions to use; when ``None`` they are built
            from the study via ``build_design``.
        agent_factories: Optional agent factories forwarded to the runner.
        problem_registry: Optional problem registry forwarded to the runner.
        parallelism: Worker count override; when ``None`` the value from
            ``study.run_budget.parallelism`` is used.
        dry_run: When true, only validation is performed and an empty list is
            returned.
        resume: When true, checkpointed results found in the output directory
            are reused and their run ids are skipped.
        checkpoint: Forwarded to the runner to enable per-run checkpointing.
        include_sqlite: Forwarded to ``export_canonical_artifacts``.

    Returns:
        Previously checkpointed results (if resuming) followed by the newly
        produced results.

    Raises:
        ValidationError: If dry-run validation reports errors, or if the
            resolved parallelism is below 1.
    """
    if conditions is not None:
        active_conditions = list(conditions)
    else:
        active_conditions = build_design(study)

    validation = dry_run_validate(study, conditions=active_conditions)
    if validation.errors:
        raise ValidationError("\n".join(validation.errors))
    if dry_run:
        return []

    output_dir = Path(study.output_dir or Path("artifacts") / study.study_id)
    output_dir.mkdir(parents=True, exist_ok=True)

    planned_specs = _build_run_specs(study=study, conditions=active_conditions)
    condition_by_id = {item.condition_id: item for item in active_conditions}

    # Without resume there is nothing to reuse, so every planned spec is pending.
    checkpointed = load_checkpointed_run_results(output_dir) if resume else {}
    prior_results: list[RunResult] = list(checkpointed.values())
    finished_run_ids: set[str] = set(checkpointed)
    pending_specs = [spec for spec in planned_specs if spec.run_id not in finished_run_ids]

    if parallelism is not None:
        worker_count = parallelism
    else:
        worker_count = study.run_budget.parallelism
    if worker_count < 1:
        raise ValidationError("parallelism must be >= 1.")

    runner_kwargs: dict[str, Any] = {
        "run_specs": pending_specs,
        "condition_by_id": condition_by_id,
        "agent_factories": agent_factories,
        "problem_registry": problem_registry,
        "output_dir": output_dir,
        "checkpoint": checkpoint,
        "fail_fast": study.run_budget.fail_fast,
    }
    if worker_count == 1:
        fresh_results = SerialRunner().run(**runner_kwargs)
    else:
        fresh_results = LocalParallelRunner().run(**runner_kwargs, max_workers=worker_count)

    combined = prior_results + fresh_results
    export_canonical_artifacts(
        study=study,
        conditions=active_conditions,
        run_results=combined,
        output_dir=output_dir,
        include_sqlite=include_sqlite,
    )
    return combined
def resume_study(
    study: Study,
    *,
    conditions: Sequence[Condition] | None = None,
    agent_factories: Mapping[str, Callable[[Condition], Any]] | None = None,
    problem_registry: Mapping[str, Any] | None = None,
    parallelism: int | None = None,
    checkpoint: bool = True,
    include_sqlite: bool = False,
) -> list[RunResult]:
    """Continue a study, reusing results that were already checkpointed.

    This is ``run_study`` with ``resume=True`` and ``dry_run=False``; every
    other argument is forwarded unchanged.

    Returns:
        Whatever ``run_study`` returns for the resumed execution.
    """
    return run_study(
        study,
        resume=True,
        dry_run=False,
        conditions=conditions,
        agent_factories=agent_factories,
        problem_registry=problem_registry,
        parallelism=parallelism,
        checkpoint=checkpoint,
        include_sqlite=include_sqlite,
    )
def dry_run_validate(study: Study, *, conditions: Sequence[Condition]) -> DryRunReport:
    """Validate run inputs before launching execution.

    Args:
        study: Study definition to validate.
        conditions: Conditions the study would be executed with.

    Returns:
        A report holding every collected error message and the number of run
        specifications that would be generated.
    """
    errors = list(validate_study(study))

    if not any(condition.admissible for condition in conditions):
        errors.append("No admissible conditions are available to execute.")

    planned_runs = len(_build_run_specs(study=study, conditions=conditions))
    if planned_runs < 1:
        errors.append("No run specifications were generated.")

    # NOTE(review): _build_run_specs already truncates its output to
    # run_budget.max_runs, so for any non-negative max_runs this check cannot
    # fire. Confirm whether over-budget plans should be capped or rejected.
    if study.run_budget.max_runs is not None and planned_runs > study.run_budget.max_runs:
        errors.append(
            "Planned runs "
            f"({planned_runs}) exceed run budget max_runs ({study.run_budget.max_runs})."
        )

    return DryRunReport(errors=errors, planned_runs=planned_runs)


def reproducible_seed(seed: int) -> Any:
    """Context manager preserving and restoring random RNG state.

    Args:
        seed: Seed applied to the ``random`` module (and to ``numpy.random``
            when NumPy can be imported) for the duration of the ``with`` body.

    Returns:
        A context manager; the previous RNG state is restored on exit.
    """
    return _reproducible_seed(seed)


# numpy.random.seed only accepts integers in [0, 2**32 - 1].
_NUMPY_SEED_MODULUS = 2**32


@contextmanager
def _reproducible_seed(seed: int) -> Any:
    """Temporarily set deterministic RNG seeds for one run.

    The ``random`` module is always seeded. NumPy is seeded on a best-effort
    basis: if it cannot be imported or seeded, it is left untouched.

    NOTE(review): both RNG states are process-global, so runs that execute
    concurrently in threads share them.
    """
    previous_state = random.getstate()
    random.seed(seed)

    numpy_state: tuple[Any, Any] | None = None
    try:
        numpy_module = importlib.import_module("numpy")
        numpy_random = numpy_module.random
        get_state = numpy_random.get_state
        set_state = numpy_random.set_state
        seed_function = numpy_random.seed
        numpy_state = (get_state(), set_state)
        # Fold the seed into NumPy's accepted range. Without this, large or
        # negative seeds raise inside this try block and NumPy silently stays
        # unseeded. In-range seeds are unchanged by the modulo.
        seed_function(seed % _NUMPY_SEED_MODULUS)
    except Exception:
        # NumPy is optional; nothing was seeded, so nothing needs restoring.
        numpy_state = None

    try:
        yield
    finally:
        random.setstate(previous_state)
        if numpy_state is not None:
            state, set_state = numpy_state
            set_state(state)


def _build_run_specs(study: Study, conditions: Sequence[Condition]) -> list[RunSpec]:
    """Materialize deterministic run specs for all admissible conditions.

    Specs are generated for every admissible condition, replicate (1-based),
    agent id, and problem id, nested in that order. When
    ``study.run_budget.max_runs`` is set, the list is sliced to that length.

    Args:
        study: Study providing ids, the run budget, and the seed policy.
        conditions: Candidate conditions; inadmissible ones are skipped.

    Returns:
        The generated run specs, possibly truncated to the run budget.
    """
    admissible_conditions = [condition for condition in conditions if condition.admissible]
    run_specs: list[RunSpec] = []

    for condition in admissible_conditions:
        agent_ids = _resolve_agent_ids(study=study, condition=condition)
        problem_ids = _resolve_problem_ids(study=study, condition=condition)

        for replicate in range(1, study.run_budget.replicates + 1):
            for agent_id in agent_ids:
                for problem_id in problem_ids:
                    # The run id is derived from the full coordinate of the run.
                    run_id = hash_identifier(
                        "run",
                        {
                            "study_id": study.study_id,
                            "condition_id": condition.condition_id,
                            "replicate": replicate,
                            "agent_id": agent_id,
                            "problem_id": problem_id,
                        },
                    )
                    seed = study.seed_policy.derive_seed(
                        study_id=study.study_id,
                        condition_id=condition.condition_id,
                        replicate=replicate,
                        salt=f"{agent_id}:{problem_id}",
                    )
                    run_specs.append(
                        RunSpec(
                            run_id=run_id,
                            study_id=study.study_id,
                            condition_id=condition.condition_id,
                            problem_id=problem_id,
                            replicate=replicate,
                            seed=seed,
                            agent_spec_ref=agent_id,
                            problem_spec_ref=problem_id,
                            execution_metadata={
                                "agent_id": agent_id,
                                "problem_id": problem_id,
                                "condition_fingerprint": condition.metadata.get("fingerprint"),
                            },
                        )
                    )

    if study.run_budget.max_runs is not None:
        return run_specs[: study.run_budget.max_runs]
    return run_specs


def _resolve_agent_ids(study: Study, condition: Condition) -> tuple[str, ...]:
    """Resolve agent IDs for one condition.

    Precedence: the first of the factor assignments ``agent_id``, ``agent``,
    ``agent_spec`` present on the condition (stringified), then the study's
    ``agent_specs``, then the single placeholder ``"default-agent"``.
    """
    for key in ("agent_id", "agent", "agent_spec"):
        if key in condition.factor_assignments:
            return (str(condition.factor_assignments[key]),)

    if study.agent_specs:
        return tuple(study.agent_specs)

    return ("default-agent",)


def _resolve_problem_ids(study: Study, condition: Condition) -> tuple[str, ...]:
    """Resolve problem IDs for one condition.

    Precedence: the first of the factor assignments ``problem_id``,
    ``problem`` present on the condition (stringified), then the study's
    ``problem_ids``, then the single placeholder ``"default-problem"``.
    """
    for key in ("problem_id", "problem"):
        if key in condition.factor_assignments:
            return (str(condition.factor_assignments[key]),)

    if study.problem_ids:
        return tuple(study.problem_ids)

    return ("default-problem",)


def _execute_single_run(
    *,
    run_spec: RunSpec,
    condition: Condition,
    agent_factories: Mapping[str, Callable[[Condition], Any]] | None,
    problem_registry: Mapping[str, Any] | None,
) -> RunResult:
    """Execute one run spec with failure isolation.

    Any ``Exception`` raised while resolving the problem, running the agent,
    evaluating, or composing metrics is converted into a ``RunStatus.FAILED``
    result instead of propagating.

    Side effect: on the success path ``run_spec.execution_metadata`` gains a
    ``"problem_family"`` entry.

    Args:
        run_spec: Specification of the run to execute.
        condition: Condition the run belongs to.
        agent_factories: Optional agent factories forwarded to ``execute_agent``.
        problem_registry: Optional registry forwarded to ``resolve_problem``.

    Returns:
        A ``RunResult`` with status ``SUCCESS`` or ``FAILED``.
    """
    started_at = utc_now_iso()
    start_time = time.perf_counter()

    try:
        # Seed only the stochastic part of the run: problem resolution, agent
        # execution, and evaluation.
        with reproducible_seed(run_spec.seed):
            problem_packet = resolve_problem(
                run_spec.problem_spec_ref,
                registry=problem_registry,
            )
            run_spec.execution_metadata["problem_family"] = problem_packet.family

            agent_execution = execute_agent(
                agent_spec_ref=run_spec.agent_spec_ref,
                run_spec=run_spec,
                condition=condition,
                problem_packet=problem_packet,
                factories=agent_factories,
            )
            evaluation_rows = evaluate_problem(problem_packet, agent_execution.output)

        # Tag every evaluation row so it can be joined back to its run.
        for row in evaluation_rows:
            row["run_id"] = run_spec.run_id

        latency_s = time.perf_counter() - start_time
        # A missing or falsy cost metric is treated as zero cost.
        cost_usd = float(agent_execution.metrics.get("cost_usd", 0.0) or 0.0)
        metrics = compose_metrics(
            agent_metrics=agent_execution.metrics,
            evaluation_rows=evaluation_rows,
            observations=agent_execution.events,
            latency_s=latency_s,
            cost_usd=cost_usd,
        )

        provenance_info = {
            "agent_id": run_spec.execution_metadata.get("agent_id"),
            "problem_id": run_spec.problem_id,
            "problem_family": problem_packet.family,
            "model_name": agent_execution.metadata.get("model_name"),
            "execution_metadata": stable_json_dumps(run_spec.execution_metadata),
        }

        return RunResult(
            run_id=run_spec.run_id,
            status=RunStatus.SUCCESS,
            outputs=agent_execution.output,
            metrics=metrics,
            evaluator_outputs=evaluation_rows,
            cost=cost_usd,
            latency=latency_s,
            trace_refs=list(agent_execution.trace_refs),
            artifact_refs=[],
            error_info=None,
            provenance_info=provenance_info,
            observations=list(agent_execution.events),
            run_spec=run_spec,
            started_at=started_at,
            ended_at=utc_now_iso(),
        )
    except Exception as exc:
        # Failure isolation boundary: report the error on the result so one
        # bad run does not abort the remaining runs.
        latency_s = time.perf_counter() - start_time
        return RunResult(
            run_id=run_spec.run_id,
            status=RunStatus.FAILED,
            outputs={},
            metrics={"latency_s": latency_s},
            evaluator_outputs=[],
            cost=0.0,
            latency=latency_s,
            trace_refs=[],
            artifact_refs=[],
            error_info=f"{type(exc).__name__}: {exc}",
            provenance_info={"agent_id": run_spec.execution_metadata.get("agent_id")},
            observations=[],
            run_spec=run_spec,
            started_at=started_at,
            ended_at=utc_now_iso(),
        )