# Source code for design_research_experiments.runners

"""Run orchestration engines and reproducible execution helpers."""

from __future__ import annotations

import importlib
import random
import sys
import time
from collections.abc import Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from tqdm.auto import tqdm

from .adapters.agents import AgentBinding, execute_agent
from .adapters.problems import evaluate_problem, resolve_problem
from .artifacts import (
    checkpoint_run_result,
    export_canonical_artifacts,
    load_checkpointed_run_results,
)
from .conditions import Condition
from .designs import build_design
from .metrics import compose_metrics
from .schemas import RunStatus, ValidationError, hash_identifier, stable_json_dumps, utc_now_iso
from .study import RunResult, RunSpec, Study, validate_study


@dataclass(slots=True)
class DryRunReport:
    """Dry-run validation report for a planned execution.

    Produced by ``dry_run_validate`` before any run is launched; when
    ``errors`` is non-empty the caller raises instead of executing.
    """

    # Human-readable validation failure messages; empty means the plan is valid.
    errors: list[str]
    # Count of run specs the execution would materialize.
    planned_runs: int


class _NoOpProgressBar:
    """Inert stand-in used when visual progress rendering is disabled.

    Implements the minimal subset of the tqdm surface that ``_RunProgress``
    touches (``update``, ``set_postfix``, ``close``) as no-ops, so callers
    never need to branch on whether a real bar exists.
    """

    def update(self, _steps: int = 1) -> None:
        """Discard progress increments."""
        return None

    def set_postfix(
        self,
        _ordered_dict: Mapping[str, int] | None = None,
        *,
        refresh: bool = True,
    ) -> None:
        """Discard postfix counter updates."""
        del refresh
        return None

    def close(self) -> None:
        """Nothing to release; accept the call silently."""
        return None


@dataclass(slots=True)
class _RunProgress:
    """Centralized progress adapter for run execution.

    Wraps either a real tqdm bar or a ``_NoOpProgressBar`` and keeps
    success/failure tallies mirrored into the bar's postfix display.
    """

    total: int
    initial: int
    success: int
    failed: int
    _bar: Any

    def __post_init__(self) -> None:
        """Push the starting success/failure counters onto the bar."""
        self._sync_postfix()

    def record_result(self, result: RunResult) -> None:
        """Advance the bar by one run and update the tallies."""
        status = result.status
        if status == RunStatus.SUCCESS:
            self.success += 1
        elif status == RunStatus.FAILED:
            self.failed += 1
        self._bar.update(1)
        self._sync_postfix()

    def close(self) -> None:
        """Release the wrapped progress bar."""
        self._bar.close()

    def _sync_postfix(self) -> None:
        """Mirror the current counters into the bar's postfix."""
        self._bar.set_postfix({"success": self.success, "failed": self.failed})


class SerialRunner:
    """Orchestration runner that executes runs one at a time, in order."""

    def run(
        self,
        *,
        run_specs: Sequence[RunSpec],
        condition_by_id: Mapping[str, Condition],
        agent_bindings: Mapping[str, AgentBinding] | None,
        problem_registry: Mapping[str, Any] | None,
        output_dir: Path,
        checkpoint: bool,
        fail_fast: bool,
        progress: _RunProgress,
    ) -> list[RunResult]:
        """Execute every run spec sequentially.

        Each result is optionally checkpointed to ``output_dir`` and then
        reported to ``progress``. When ``fail_fast`` is set, execution stops
        after the first failed run; already-collected results are returned.
        """
        completed: list[RunResult] = []
        for spec in run_specs:
            outcome = _execute_single_run(
                run_spec=spec,
                condition=condition_by_id[spec.condition_id],
                agent_bindings=agent_bindings,
                problem_registry=problem_registry,
            )
            completed.append(outcome)
            if checkpoint:
                checkpoint_run_result(outcome, output_dir=output_dir)
            progress.record_result(outcome)
            if fail_fast and outcome.status == RunStatus.FAILED:
                break
        return completed


class LocalParallelRunner:
    """Orchestration runner that fans runs out to a local thread pool."""

    def run(
        self,
        *,
        run_specs: Sequence[RunSpec],
        condition_by_id: Mapping[str, Condition],
        agent_bindings: Mapping[str, AgentBinding] | None,
        problem_registry: Mapping[str, Any] | None,
        output_dir: Path,
        checkpoint: bool,
        fail_fast: bool,
        max_workers: int,
        progress: _RunProgress,
    ) -> list[RunResult]:
        """Execute run specs concurrently on up to ``max_workers`` threads.

        Results are collected in completion order, checkpointed (when
        requested) and reported to ``progress`` from the calling thread.
        When ``fail_fast`` is set and a failed run is observed, futures that
        have not started yet are cancelled; runs already executing cannot be
        interrupted, and their results are not collected after the break.
        """
        gathered: list[RunResult] = []

        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            submitted: list[Future[RunResult]] = [
                pool.submit(
                    _execute_single_run,
                    run_spec=spec,
                    condition=condition_by_id[spec.condition_id],
                    agent_bindings=agent_bindings,
                    problem_registry=problem_registry,
                )
                for spec in run_specs
            ]

            for finished in as_completed(submitted):
                outcome = finished.result()
                gathered.append(outcome)
                if checkpoint:
                    checkpoint_run_result(outcome, output_dir=output_dir)
                progress.record_result(outcome)
                if fail_fast and outcome.status == RunStatus.FAILED:
                    # Best-effort abort: only not-yet-started futures can be
                    # cancelled; the pool's context exit still waits for
                    # in-flight runs to finish.
                    for other in submitted:
                        if not other.done():
                            other.cancel()
                    break

        return gathered


def run_study(
    study: Study,
    *,
    conditions: Sequence[Condition] | None = None,
    agent_bindings: Mapping[str, AgentBinding] | None = None,
    problem_registry: Mapping[str, Any] | None = None,
    parallelism: int | None = None,
    dry_run: bool = False,
    resume: bool = False,
    checkpoint: bool = True,
    include_sqlite: bool = False,
    show_progress: bool | None = None,
) -> list[RunResult]:
    """Run a study end-to-end and export canonical artifacts.

    Args:
        study: Study definition to execute.
        conditions: Explicit condition list; when ``None`` the design is
            built from ``study`` via ``build_design``.
        agent_bindings: Optional agent bindings passed through to each run.
        problem_registry: Optional problem registry passed through to each run.
        parallelism: Worker-count override; defaults to
            ``study.run_budget.parallelism``.
        dry_run: Validate only and return an empty list without executing.
        resume: Skip runs already checkpointed under the output directory.
        checkpoint: Persist each run result as it completes.
        include_sqlite: Also emit the SQLite artifact during export.
        show_progress: Force the progress bar on/off; ``None`` auto-detects.

    Returns:
        All run results (previously checkpointed plus newly executed).

    Raises:
        ValidationError: If dry-run validation reports errors or the resolved
            parallelism is below 1.
    """
    resolved_conditions = list(conditions) if conditions is not None else build_design(study)

    report = dry_run_validate(study, conditions=resolved_conditions)
    if report.errors:
        raise ValidationError("\n".join(report.errors))
    if dry_run:
        return []

    output_dir = Path(study.output_dir or Path("artifacts") / study.study_id)
    output_dir.mkdir(parents=True, exist_ok=True)

    all_run_specs = _build_run_specs(study=study, conditions=resolved_conditions)
    planned_run_ids = {run_spec.run_id for run_spec in all_run_specs}
    condition_by_id = {condition.condition_id: condition for condition in resolved_conditions}

    existing_results: list[RunResult] = []
    completed_run_ids: set[str] = set()
    if resume:
        checkpointed = load_checkpointed_run_results(output_dir)
        existing_results = list(checkpointed.values())
        completed_run_ids = set(checkpointed)

    pending_run_specs = [
        run_spec for run_spec in all_run_specs if run_spec.run_id not in completed_run_ids
    ]
    # Progress counters only reflect checkpointed results that belong to the
    # current plan; stale checkpoints are still included in the final export.
    existing_progress_results = [
        result for result in existing_results if result.run_id in planned_run_ids
    ]

    resolved_parallelism = (
        parallelism if parallelism is not None else study.run_budget.parallelism
    )
    if resolved_parallelism < 1:
        raise ValidationError("parallelism must be >= 1.")

    progress = _create_run_progress(
        study_id=study.study_id,
        total=len(all_run_specs),
        initial=len(all_run_specs) - len(pending_run_specs),
        run_results=existing_progress_results,
        pending_runs=len(pending_run_specs),
        show_progress=show_progress,
    )
    try:
        if resolved_parallelism == 1:
            new_results = SerialRunner().run(
                run_specs=pending_run_specs,
                condition_by_id=condition_by_id,
                agent_bindings=agent_bindings,
                problem_registry=problem_registry,
                output_dir=output_dir,
                checkpoint=checkpoint,
                fail_fast=study.run_budget.fail_fast,
                progress=progress,
            )
        else:
            new_results = LocalParallelRunner().run(
                run_specs=pending_run_specs,
                condition_by_id=condition_by_id,
                agent_bindings=agent_bindings,
                problem_registry=problem_registry,
                output_dir=output_dir,
                checkpoint=checkpoint,
                fail_fast=study.run_budget.fail_fast,
                max_workers=resolved_parallelism,
                progress=progress,
            )
        all_results = existing_results + new_results
        export_canonical_artifacts(
            study=study,
            conditions=resolved_conditions,
            run_results=all_results,
            output_dir=output_dir,
            include_sqlite=include_sqlite,
        )
        return all_results
    finally:
        progress.close()
def resume_study(
    study: Study,
    *,
    conditions: Sequence[Condition] | None = None,
    agent_bindings: Mapping[str, AgentBinding] | None = None,
    problem_registry: Mapping[str, Any] | None = None,
    parallelism: int | None = None,
    checkpoint: bool = True,
    include_sqlite: bool = False,
    show_progress: bool | None = None,
) -> list[RunResult]:
    """Resume a study from checkpointed run results.

    Thin convenience wrapper around :func:`run_study` that forces
    ``resume=True`` and ``dry_run=False`` while forwarding every other
    argument unchanged.
    """
    return run_study(
        study,
        conditions=conditions,
        agent_bindings=agent_bindings,
        problem_registry=problem_registry,
        parallelism=parallelism,
        dry_run=False,
        resume=True,
        checkpoint=checkpoint,
        include_sqlite=include_sqlite,
        show_progress=show_progress,
    )
def dry_run_validate(study: Study, *, conditions: Sequence[Condition]) -> DryRunReport:
    """Validate run inputs before launching execution.

    Collects study-level validation errors, checks that at least one
    admissible condition and at least one run spec exist, and enforces the
    ``max_runs`` budget. Never raises; errors are returned in the report.
    """
    errors = list(validate_study(study))

    if not any(condition.admissible for condition in conditions):
        errors.append("No admissible conditions are available to execute.")

    planned_runs = len(_build_run_specs(study=study, conditions=conditions))
    if planned_runs < 1:
        errors.append("No run specifications were generated.")

    max_runs = study.run_budget.max_runs
    if max_runs is not None and planned_runs > max_runs:
        errors.append(
            f"Planned runs ({planned_runs}) exceed run budget max_runs ({max_runs})."
        )
    return DryRunReport(errors=errors, planned_runs=planned_runs)


def reproducible_seed(seed: int) -> Any:
    """Context manager preserving and restoring random RNG state."""
    return _reproducible_seed(seed)


@contextmanager
def _reproducible_seed(seed: int) -> Any:
    """Temporarily pin the stdlib (and, when available, NumPy) RNGs to *seed*.

    Prior RNG state is captured on entry and restored on exit. NumPy is an
    optional dependency: any failure while importing or seeding it is
    swallowed and only the stdlib RNG is managed in that case.
    """
    saved_stdlib_state = random.getstate()
    random.seed(seed)

    saved_numpy: tuple[Any, Any] | None = None
    try:
        numpy_random = importlib.import_module("numpy").random
        # Capture the restore callable together with the state so the
        # cleanup path does not need to re-import numpy.
        saved_numpy = (numpy_random.get_state(), numpy_random.set_state)
        numpy_random.seed(seed)
    except Exception:
        saved_numpy = None

    try:
        yield
    finally:
        random.setstate(saved_stdlib_state)
        if saved_numpy is not None:
            state, restore = saved_numpy
            restore(state)


def _build_run_specs(study: Study, conditions: Sequence[Condition]) -> list[RunSpec]:
    """Materialize deterministic run specs for all admissible conditions.

    One spec is produced per (condition, replicate, agent, problem) tuple,
    with a content-hashed run ID and a policy-derived seed. The list is
    truncated to ``study.run_budget.max_runs`` when that budget is set.
    """
    run_specs: list[RunSpec] = []
    for condition in (candidate for candidate in conditions if candidate.admissible):
        agent_ids = _resolve_agent_ids(study=study, condition=condition)
        problem_ids = _resolve_problem_ids(study=study, condition=condition)
        for replicate in range(1, study.run_budget.replicates + 1):
            for agent_id in agent_ids:
                for problem_id in problem_ids:
                    run_id = hash_identifier(
                        "run",
                        {
                            "study_id": study.study_id,
                            "condition_id": condition.condition_id,
                            "replicate": replicate,
                            "agent_id": agent_id,
                            "problem_id": problem_id,
                        },
                    )
                    seed = study.seed_policy.derive_seed(
                        study_id=study.study_id,
                        condition_id=condition.condition_id,
                        replicate=replicate,
                        salt=f"{agent_id}:{problem_id}",
                    )
                    run_specs.append(
                        RunSpec(
                            run_id=run_id,
                            study_id=study.study_id,
                            condition_id=condition.condition_id,
                            problem_id=problem_id,
                            replicate=replicate,
                            seed=seed,
                            agent_spec_ref=agent_id,
                            problem_spec_ref=problem_id,
                            execution_metadata={
                                "agent_id": agent_id,
                                "problem_id": problem_id,
                                "condition_fingerprint": condition.metadata.get("fingerprint"),
                            },
                        )
                    )
    if study.run_budget.max_runs is not None:
        return run_specs[: study.run_budget.max_runs]
    return run_specs


def _resolve_agent_ids(study: Study, condition: Condition) -> tuple[str, ...]:
    """Resolve agent IDs for one condition.

    A condition-level factor assignment wins; otherwise all study-level
    agent specs are used, falling back to a single default ID.
    """
    assignments = condition.factor_assignments
    for key in ("agent_id", "agent", "agent_spec"):
        if key in assignments:
            return (str(assignments[key]),)
    return tuple(study.agent_specs) if study.agent_specs else ("default-agent",)


def _resolve_problem_ids(study: Study, condition: Condition) -> tuple[str, ...]:
    """Resolve problem IDs for one condition.

    Mirrors ``_resolve_agent_ids``: condition-level assignment first, then
    study-level problem IDs, then a single default ID.
    """
    assignments = condition.factor_assignments
    for key in ("problem_id", "problem"):
        if key in assignments:
            return (str(assignments[key]),)
    return tuple(study.problem_ids) if study.problem_ids else ("default-problem",)


def _create_run_progress(
    *,
    study_id: str,
    total: int,
    initial: int,
    run_results: Sequence[RunResult],
    pending_runs: int,
    show_progress: bool | None,
) -> _RunProgress:
    """Build the run-progress adapter for one execution.

    Seeds the success/failure tallies from already-known results and wraps
    either a real tqdm bar (written to stderr) or a no-op shim.
    """
    statuses = [result.status for result in run_results]
    success = statuses.count(RunStatus.SUCCESS)
    failed = statuses.count(RunStatus.FAILED)

    if _should_render_progress(show_progress=show_progress, pending_runs=pending_runs):
        bar: Any = tqdm(
            total=total,
            initial=initial,
            desc=study_id,
            unit="run",
            dynamic_ncols=True,
            leave=True,
            file=sys.stderr,
        )
    else:
        bar = _NoOpProgressBar()
    return _RunProgress(total=total, initial=initial, success=success, failed=failed, _bar=bar)


def _should_render_progress(*, show_progress: bool | None, pending_runs: int) -> bool:
    """Decide whether a visible progress bar should be created.

    Nothing pending means nothing to render; an explicit caller preference
    always wins; otherwise render only when stderr looks like a TTY.
    """
    if pending_runs < 1:
        return False
    if show_progress is not None:
        return show_progress
    isatty = getattr(sys.stderr, "isatty", None)
    return bool(callable(isatty) and isatty())


def _execute_single_run(
    *,
    run_spec: RunSpec,
    condition: Condition,
    agent_bindings: Mapping[str, AgentBinding] | None,
    problem_registry: Mapping[str, Any] | None,
) -> RunResult:
    """Execute one run spec with failure isolation.

    All work happens under a deterministic RNG seed. Any exception yields a
    FAILED result carrying the error summary instead of propagating, so one
    bad run cannot abort the whole study.
    """
    started_at = utc_now_iso()
    start_time = time.perf_counter()
    try:
        # NOTE(review): the seeded scope below was reconstructed from a
        # collapsed source line; RNG state is restored on exit either way.
        with reproducible_seed(run_spec.seed):
            problem_packet = resolve_problem(
                run_spec.problem_spec_ref,
                registry=problem_registry,
            )
            run_spec.execution_metadata["problem_family"] = problem_packet.family
            agent_execution = execute_agent(
                agent_spec_ref=run_spec.agent_spec_ref,
                run_spec=run_spec,
                condition=condition,
                problem_packet=problem_packet,
                agent_bindings=agent_bindings,
            )
            # Propagate selected non-empty agent metadata into the spec's
            # execution metadata for provenance.
            for key in (
                "model_name",
                "model_provider",
                "agent_kind",
                "pattern_name",
                "request_id",
                "trace_dir",
                "trace_path",
            ):
                value = agent_execution.metadata.get(key)
                if value in (None, ""):
                    continue
                run_spec.execution_metadata[key] = value
            evaluation_rows = evaluate_problem(problem_packet, agent_execution.output)
            for row in evaluation_rows:
                row["run_id"] = run_spec.run_id

        latency_s = time.perf_counter() - start_time
        cost_usd = float(agent_execution.metrics.get("cost_usd", 0.0) or 0.0)
        metrics = compose_metrics(
            agent_metrics=agent_execution.metrics,
            evaluation_rows=evaluation_rows,
            observations=agent_execution.events,
            latency_s=latency_s,
            cost_usd=cost_usd,
        )
        provenance_info = {
            "agent_id": run_spec.execution_metadata.get("agent_id"),
            "problem_id": run_spec.problem_id,
            "problem_family": problem_packet.family,
            "model_name": agent_execution.metadata.get("model_name"),
            "model_provider": agent_execution.metadata.get("model_provider"),
            "request_id": agent_execution.metadata.get("request_id"),
            "trace_dir": agent_execution.metadata.get("trace_dir"),
            "trace_path": agent_execution.metadata.get("trace_path"),
            "execution_metadata": stable_json_dumps(run_spec.execution_metadata),
        }
        return RunResult(
            run_id=run_spec.run_id,
            status=RunStatus.SUCCESS,
            outputs=agent_execution.output,
            metrics=metrics,
            evaluator_outputs=evaluation_rows,
            cost=cost_usd,
            latency=latency_s,
            trace_refs=list(agent_execution.trace_refs),
            artifact_refs=[],
            error_info=None,
            provenance_info=provenance_info,
            observations=list(agent_execution.events),
            run_spec=run_spec,
            started_at=started_at,
            ended_at=utc_now_iso(),
        )
    except Exception as exc:
        latency_s = time.perf_counter() - start_time
        return RunResult(
            run_id=run_spec.run_id,
            status=RunStatus.FAILED,
            outputs={},
            metrics={"latency_s": latency_s},
            evaluator_outputs=[],
            cost=0.0,
            latency=latency_s,
            trace_refs=[],
            artifact_refs=[],
            error_info=f"{type(exc).__name__}: {exc}",
            provenance_info={"agent_id": run_spec.execution_metadata.get("agent_id")},
            observations=[],
            run_spec=run_spec,
            started_at=started_at,
            ended_at=utc_now_iso(),
        )