Source code for design_research_experiments.study

"""Top-level study models and validation helpers."""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, cast

from .conditions import Constraint, Factor, FactorKind, Level
from .hypotheses import (
    AnalysisPlan,
    Hypothesis,
    OutcomeSpec,
    coerce_analysis_plan,
    coerce_hypothesis,
    coerce_outcome,
    validate_hypothesis_bindings,
)
from .io import json_io, yaml_io
from .schemas import (
    ProvenanceMetadata,
    RunBudget,
    RunStatus,
    SeedPolicy,
    ValidationError,
    to_jsonable,
)


[docs] @dataclass(slots=True) class Block: """Blocking structure for design-of-experiments materialization.""" name: str levels: tuple[Any, ...] metadata: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: """Validate block content.""" if not self.name.strip(): raise ValidationError("Block.name must be non-empty.") if not self.levels: raise ValidationError(f"Block '{self.name}' must contain at least one level.")
[docs] @dataclass(slots=True) class RunSpec: """One executable run specification.""" run_id: str study_id: str condition_id: str problem_id: str replicate: int seed: int agent_spec_ref: Any problem_spec_ref: Any execution_metadata: dict[str, Any] = field(default_factory=dict)
[docs] @dataclass(slots=True) class RunResult: """Normalized result bundle for one run.""" run_id: str status: RunStatus outputs: dict[str, Any] = field(default_factory=dict) metrics: dict[str, Any] = field(default_factory=dict) evaluator_outputs: list[dict[str, Any]] = field(default_factory=list) cost: float = 0.0 latency: float = 0.0 trace_refs: list[str] = field(default_factory=list) artifact_refs: list[str] = field(default_factory=list) error_info: str | None = None provenance_info: dict[str, Any] = field(default_factory=dict) observations: list[Any] = field(default_factory=list) run_spec: RunSpec | None = None started_at: str | None = None ended_at: str | None = None
[docs] @dataclass(slots=True) class Study: """Top-level experiment definition.""" study_id: str title: str description: str authors: tuple[str, ...] = () rationale: str = "" tags: tuple[str, ...] = () hypotheses: tuple[Hypothesis, ...] = () factors: tuple[Factor, ...] = () blocks: tuple[Block, ...] = () constraints: tuple[Constraint, ...] = () design_spec: dict[str, Any] = field(default_factory=lambda: {"kind": "full_factorial"}) outcomes: tuple[OutcomeSpec, ...] = () analysis_plans: tuple[AnalysisPlan, ...] = () run_budget: RunBudget = field(default_factory=RunBudget) seed_policy: SeedPolicy = field(default_factory=SeedPolicy) output_dir: Path | None = None provenance_metadata: dict[str, Any] = field(default_factory=dict) notes: str = "" problem_ids: tuple[str, ...] = () agent_specs: tuple[str, ...] = () primary_outcomes: tuple[str, ...] = () secondary_outcomes: tuple[str, ...] = () def __post_init__(self) -> None: """Validate basic study fields and normalize output directory.""" if not self.study_id.strip(): raise ValidationError("Study.study_id must be non-empty.") if not self.title.strip(): raise ValidationError("Study.title must be non-empty.") if self.output_dir is None: self.output_dir = Path("artifacts") / self.study_id
[docs] def to_dict(self) -> dict[str, Any]: """Serialize the study to a stable JSON/YAML-friendly mapping.""" return cast(dict[str, Any], to_jsonable(self))
[docs] def to_yaml(self, path: str | Path) -> Path: """Write the study definition to YAML.""" return yaml_io.write_yaml(Path(path), self.to_dict())
[docs] def to_json(self, path: str | Path) -> Path: """Write the study definition to JSON.""" return json_io.write_json(Path(path), self.to_dict())
[docs] @classmethod def from_dict(cls, payload: Mapping[str, Any]) -> Study: """Construct a study from a loose mapping payload.""" factors = tuple( _coerce_factor(factor) for factor in cast(Sequence[Any], payload.get("factors", ())) ) blocks = tuple( _coerce_block(block) for block in cast(Sequence[Any], payload.get("blocks", ())) ) constraints = tuple( _coerce_constraint(constraint) for constraint in cast(Sequence[Any], payload.get("constraints", ())) ) hypotheses = tuple( coerce_hypothesis(hypothesis) for hypothesis in cast(Sequence[Any], payload.get("hypotheses", ())) ) outcomes = tuple( coerce_outcome(outcome) for outcome in cast(Sequence[Any], payload.get("outcomes", ())) ) analysis_plans = tuple( coerce_analysis_plan(plan) for plan in cast(Sequence[Any], payload.get("analysis_plans", ())) ) run_budget_payload = payload.get("run_budget") run_budget = ( run_budget_payload if isinstance(run_budget_payload, RunBudget) else RunBudget(**cast(dict[str, Any], run_budget_payload or {})) ) seed_policy_payload = payload.get("seed_policy") seed_policy = ( seed_policy_payload if isinstance(seed_policy_payload, SeedPolicy) else SeedPolicy(**cast(dict[str, Any], seed_policy_payload or {})) ) output_dir_payload = payload.get("output_dir") output_dir = Path(str(output_dir_payload)) if output_dir_payload else None return cls( study_id=str(payload["study_id"]), title=str(payload.get("title", "")), description=str(payload.get("description", "")), authors=tuple(cast(Sequence[str], payload.get("authors", ()))), rationale=str(payload.get("rationale", "")), tags=tuple(cast(Sequence[str], payload.get("tags", ()))), hypotheses=hypotheses, factors=factors, blocks=blocks, constraints=constraints, design_spec=dict(cast(Mapping[str, Any], payload.get("design_spec", {}))), outcomes=outcomes, analysis_plans=analysis_plans, run_budget=run_budget, seed_policy=seed_policy, output_dir=output_dir, provenance_metadata=dict( cast(Mapping[str, Any], payload.get("provenance_metadata", {})) ), notes=str(payload.get("notes", "")), problem_ids=tuple(cast(Sequence[str], payload.get("problem_ids", ()))), agent_specs=tuple(cast(Sequence[str], payload.get("agent_specs", ()))), primary_outcomes=tuple(cast(Sequence[str], payload.get("primary_outcomes", ()))), secondary_outcomes=tuple(cast(Sequence[str], payload.get("secondary_outcomes", ()))), )
[docs] @classmethod def from_yaml(cls, path: str | Path) -> Study: """Load a study from YAML.""" return cls.from_dict(yaml_io.read_yaml(Path(path)))
[docs] @classmethod def from_json(cls, path: str | Path) -> Study: """Load a study from JSON.""" return cls.from_dict(json_io.read_json(Path(path)))
[docs] def validate_study(study: Study) -> list[str]: """Validate cross-object references and study consistency.""" errors: list[str] = [] factor_names = [factor.name for factor in study.factors] block_names = [block.name for block in study.blocks] outcome_names = [outcome.name for outcome in study.outcomes] analysis_plan_ids = [plan.analysis_plan_id for plan in study.analysis_plans] hypothesis_ids = [hypothesis.hypothesis_id for hypothesis in study.hypotheses] errors.extend(_duplicate_errors("factor", factor_names)) errors.extend(_duplicate_errors("block", block_names)) errors.extend(_duplicate_errors("outcome", outcome_names)) errors.extend(_duplicate_errors("analysis plan", analysis_plan_ids)) errors.extend(_duplicate_errors("hypothesis", hypothesis_ids)) outcome_name_set = set(outcome_names) for outcome_name in study.primary_outcomes: if outcome_name not in outcome_name_set: errors.append(f"Primary outcome '{outcome_name}' is not defined in outcomes.") for outcome_name in study.secondary_outcomes: if outcome_name not in outcome_name_set: errors.append(f"Secondary outcome '{outcome_name}' is not defined in outcomes.") errors.extend( validate_hypothesis_bindings( study.hypotheses, factor_names=factor_names, outcome_names=outcome_names, analysis_plan_ids=analysis_plan_ids, ) ) for analysis_plan in study.analysis_plans: for hypothesis_id in analysis_plan.hypothesis_ids: if hypothesis_id not in set(hypothesis_ids): errors.append( "Analysis plan " f"'{analysis_plan.analysis_plan_id}' references unknown hypothesis " f"'{hypothesis_id}'." ) for outcome_name in analysis_plan.outcomes: if outcome_name not in outcome_name_set: errors.append( f"Analysis plan '{analysis_plan.analysis_plan_id}' references unknown outcome " f"'{outcome_name}'." ) if not study.problem_ids: errors.append("Study.problem_ids must include at least one problem ID.") if study.run_budget.max_runs is not None: requested_runs = ( len(study.problem_ids) * max(1, len(study.agent_specs) or 1) * study.run_budget.replicates ) if requested_runs > study.run_budget.max_runs: errors.append( "Run budget max_runs is below the configured problem/agent/replicate plan." ) return errors
def load_study(path: str | Path) -> Study: """Load a study from YAML or JSON based on file extension.""" resolved = Path(path) suffix = resolved.suffix.lower() if suffix in {".yaml", ".yml"}: return Study.from_yaml(resolved) if suffix == ".json": return Study.from_json(resolved) raise ValidationError("Study file must end with .yaml/.yml or .json.") def _duplicate_errors(label: str, names: Sequence[str]) -> list[str]: """Return duplicate-name errors for one label class.""" seen: set[str] = set() errors: list[str] = [] for name in names: if name in seen: errors.append(f"Duplicate {label} name detected: '{name}'.") seen.add(name) return errors def _coerce_factor(value: Factor | Mapping[str, Any]) -> Factor: """Coerce a mapping payload into a `Factor` instance.""" if isinstance(value, Factor): return value levels = tuple( level if isinstance(level, Level) else Level( name=str(cast(Mapping[str, Any], level)["name"]), value=cast(Mapping[str, Any], level).get("value"), label=cast(str | None, cast(Mapping[str, Any], level).get("label")), metadata=dict( cast(Mapping[str, Any], cast(Mapping[str, Any], level).get("metadata", {})) ), ) for level in cast(Sequence[Any], value.get("levels", ())) ) return Factor( name=str(value["name"]), description=str(value.get("description", "")), kind=FactorKind(str(value.get("kind", FactorKind.MANIPULATED.value))), levels=levels, dtype=cast(str | None, value.get("dtype")), default=value.get("default"), metadata=dict(cast(Mapping[str, Any], value.get("metadata", {}))), ) def _coerce_block(value: Block | Mapping[str, Any]) -> Block: """Coerce a mapping payload into a `Block` instance.""" if isinstance(value, Block): return value return Block( name=str(value["name"]), levels=tuple(cast(Sequence[Any], value.get("levels", ()))), metadata=dict(cast(Mapping[str, Any], value.get("metadata", {}))), ) def _coerce_constraint(value: Constraint | Mapping[str, Any]) -> Constraint: """Coerce a mapping payload into a `Constraint` instance.""" if isinstance(value, Constraint): return value return Constraint( constraint_id=str(value["constraint_id"]), description=str(value.get("description", "")), expression=cast(str | None, value.get("expression")), callable_ref=cast(str | None, value.get("callable_ref")), severity=value.get("severity", "error"), ) def build_default_provenance() -> dict[str, Any]: """Capture a baseline provenance payload for study manifests.""" return cast( dict[str, Any], to_jsonable( ProvenanceMetadata.capture( package_names=( "design-research-experiments", "design-research-agents", "design-research-problems", "design-research-analysis", ) ) ), )