"""Shared schema primitives and serialization helpers for experiment workflows."""
from __future__ import annotations
import hashlib
import importlib
import importlib.metadata
import json
import platform
import socket
import subprocess
import sys
from collections.abc import Callable, Iterable, Mapping
from dataclasses import asdict, dataclass, field, is_dataclass
from datetime import UTC, datetime
from enum import Enum, StrEnum
from pathlib import Path
from typing import Any, cast
# Version string for the schema primitives defined in this module.
# NOTE(review): not referenced by any code visible here — presumably stamped
# into serialized artifacts elsewhere; confirm before bumping.
SCHEMA_VERSION = "0.1.0"
class ValidationError(ValueError):
    """Signal that a schema object was built with invalid field values.

    Subclasses ``ValueError`` so that callers catching ``ValueError`` also
    catch schema validation failures.
    """
class ConstraintSeverity(StrEnum):
"""Severity level for admissibility constraints."""
ERROR = "error"
WARNING = "warning"
class RunStatus(StrEnum):
"""Lifecycle status of one run."""
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
SKIPPED = "skipped"
class ObservationLevel(StrEnum):
"""Granularity level for an observation row."""
STUDY = "study"
RUN = "run"
TRIAL = "trial"
STEP = "step"
TOOL_CALL = "tool-call"
EVALUATION = "evaluation"
[docs]
@dataclass(slots=True)
class SeedPolicy:
"""Deterministic seed policy for run generation.
Args:
base_seed: Study-wide base seed.
strategy: Seed derivation strategy name.
per_run_offset: Numeric offset mixed into each run seed.
"""
base_seed: int = 0
strategy: str = "condition_replicate"
per_run_offset: int = 9973
[docs]
def derive_seed(self, study_id: str, condition_id: str, replicate: int, salt: str = "") -> int:
"""Derive one deterministic per-run seed."""
payload = {
"base_seed": self.base_seed,
"strategy": self.strategy,
"per_run_offset": self.per_run_offset,
"study_id": study_id,
"condition_id": condition_id,
"replicate": replicate,
"salt": salt,
}
digest = hashlib.sha256(stable_json_dumps(payload).encode("utf-8")).hexdigest()
return int(digest[:16], 16) % (2**31 - 1)
[docs]
@dataclass(slots=True)
class RunBudget:
"""Execution budget controls for one study.
Args:
replicates: Number of run replicates for each unit.
max_runs: Optional upper bound on total runs.
parallelism: Local worker count.
fail_fast: Stop after first failure when `True`.
"""
replicates: int = 1
max_runs: int | None = None
parallelism: int = 1
fail_fast: bool = False
def __post_init__(self) -> None:
"""Validate budget shape."""
if self.replicates < 1:
raise ValidationError("RunBudget.replicates must be >= 1.")
if self.parallelism < 1:
raise ValidationError("RunBudget.parallelism must be >= 1.")
if self.max_runs is not None and self.max_runs < 1:
raise ValidationError("RunBudget.max_runs must be >= 1 when provided.")
@dataclass(slots=True)
class ProvenanceMetadata:
"""Captured runtime provenance for reproducibility.
Args:
captured_at: UTC timestamp when provenance was captured.
host: Hostname where execution happened.
platform: Platform descriptor.
python_version: Runtime Python version.
package_versions: Resolved package versions.
git_sha: Git commit SHA when available.
extra: Additional caller-provided metadata.
"""
captured_at: str
host: str
platform: str
python_version: str
package_versions: dict[str, str] = field(default_factory=dict)
git_sha: str | None = None
extra: dict[str, Any] = field(default_factory=dict)
@classmethod
def capture(
cls,
package_names: Iterable[str] = (),
*,
cwd: Path | None = None,
extra: Mapping[str, Any] | None = None,
) -> ProvenanceMetadata:
"""Capture reproducibility metadata from the local environment."""
versions: dict[str, str] = {}
for package_name in package_names:
try:
versions[package_name] = importlib.metadata.version(package_name)
except importlib.metadata.PackageNotFoundError:
versions[package_name] = "not-installed"
return cls(
captured_at=utc_now_iso(),
host=socket.gethostname(),
platform=platform.platform(),
python_version=sys.version.split()[0],
package_versions=versions,
git_sha=resolve_git_sha(cwd=cwd),
extra=dict(extra or {}),
)
@dataclass(slots=True)
class Observation:
    """One normalized row of a process trace.

    Args:
        timestamp: UTC timestamp string.
        record_id: Stable row identifier.
        text: Human-readable event text.
        session_id: Session or conversation identifier.
        actor_id: Actor identifier.
        event_type: Event kind label.
        meta_json: Structured event metadata.
        level: Observation granularity.
        study_id: Optional study ID.
        run_id: Optional run ID.
        condition_id: Optional condition ID.
        trial_id: Optional trial ID.
        step_id: Optional step ID.
        tool_name: Optional tool name.
        evaluation_id: Optional evaluation record ID.
    """

    timestamp: str
    record_id: str
    text: str
    session_id: str
    actor_id: str
    event_type: str
    meta_json: dict[str, Any] = field(default_factory=dict)
    level: ObservationLevel = ObservationLevel.STEP
    study_id: str | None = None
    run_id: str | None = None
    condition_id: str | None = None
    trial_id: str | None = None
    step_id: str | None = None
    tool_name: str | None = None
    evaluation_id: str | None = None

    def to_row(self) -> dict[str, Any]:
        """Flatten the observation into a single export row.

        ``meta_json`` is serialized to a string with ``stable_json_dumps``
        and ``level`` is reduced to its string value. All other fields are
        copied through unchanged.

        Returns:
            A dict keyed by field name, in field declaration order.
        """
        leading = ("timestamp", "record_id", "text", "session_id", "actor_id", "event_type")
        trailing = (
            "study_id",
            "run_id",
            "condition_id",
            "trial_id",
            "step_id",
            "tool_name",
            "evaluation_id",
        )
        row: dict[str, Any] = {name: getattr(self, name) for name in leading}
        # The two converted columns sit between the required and optional fields.
        row["meta_json"] = stable_json_dumps(self.meta_json)
        row["level"] = self.level.value
        for name in trailing:
            row[name] = getattr(self, name)
        return row
def utc_now_iso() -> str:
"""Return an ISO-8601 UTC timestamp without microseconds."""
return datetime.now(UTC).replace(microsecond=0).isoformat()
def resolve_git_sha(cwd: Path | None = None) -> str | None:
    """Resolve the current git SHA if the working directory is a git checkout.

    This is a best-effort lookup: any failure to launch git or to read a
    commit yields ``None`` rather than an exception.

    Args:
        cwd: Directory in which to run git. ``None`` uses the current
            working directory of the process.

    Returns:
        The output of ``git rev-parse HEAD`` with surrounding whitespace
        removed, or ``None`` when git cannot be run, exits non-zero, or
        prints nothing.
    """
    try:
        output = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            cwd=str(cwd) if cwd is not None else None,
            stderr=subprocess.DEVNULL,
            text=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # OSError covers a missing git binary or missing cwd
        # (FileNotFoundError), a cwd that is not a directory
        # (NotADirectoryError) and a non-executable git (PermissionError).
        # CalledProcessError covers a non-zero exit, e.g. not a git checkout.
        return None
    return output.strip() or None
def stable_json_dumps(data: Any) -> str:
    """Render arbitrary data as compact JSON suitable for IDs and manifests.

    The value is first normalized with ``to_jsonable``. Keys are then sorted,
    whitespace is removed from the separators, and non-ASCII characters are
    escaped.
    """
    normalized = to_jsonable(data)
    compact_separators = (",", ":")
    return json.dumps(
        normalized,
        ensure_ascii=True,
        separators=compact_separators,
        sort_keys=True,
    )
def to_jsonable(value: Any) -> Any:
    """Recursively convert a Python value to a JSON-serializable structure.

    Conversion rules, checked in this order:

    * ``None``, ``str``, ``int``, ``float`` and ``bool`` pass through.
    * ``Enum`` members become their ``.value``.
    * ``Path`` becomes its string form.
    * Any ``Mapping`` becomes a dict with stringified keys.
    * ``list`` and ``tuple`` become lists, preserving order.
    * ``set`` and ``frozenset`` become lists sorted by the canonical JSON
      text of each converted element.
    * Dataclass instances are converted via ``dataclasses.asdict``.
    * Anything else falls back to ``str(value)``.

    Args:
        value: The value to convert.

    Returns:
        A structure built only from dicts, lists and scalar values.
    """
    if value is None:
        return None
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, Path):
        return str(value)
    # Accept every Mapping, not just dict, so read-only or custom mappings
    # keep their structure instead of collapsing to ``str(value)``.
    if isinstance(value, Mapping):
        return {str(key): to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_jsonable(item) for item in value]
    if isinstance(value, (set, frozenset)):
        # Set iteration order is not stable across processes (string hashes
        # are randomized), so impose a total order. Sorting by JSON text
        # also works for mixed element types that cannot be compared.
        converted = [to_jsonable(item) for item in value]
        return sorted(
            converted,
            key=lambda item: json.dumps(item, sort_keys=True, default=str),
        )
    # ``is_dataclass`` is also true for dataclass *types*; only instances
    # can be passed to ``asdict``.
    if is_dataclass(value) and not isinstance(value, type):
        return to_jsonable(asdict(value))
    return str(value)
def hash_identifier(prefix: str, payload: Mapping[str, Any], *, length: int = 12) -> str:
    """Derive a short, deterministic identifier from a payload.

    Args:
        prefix: Text placed before the hash fragment.
        payload: Data serialized with ``stable_json_dumps`` and hashed.
        length: Number of leading hex characters of the SHA-256 digest to keep.

    Returns:
        ``prefix``, a hyphen, and the truncated hex digest.
    """
    serialized = stable_json_dumps(payload).encode("utf-8")
    hex_digest = hashlib.sha256(serialized).hexdigest()
    fragment = hex_digest[:length]
    return f"{prefix}-{fragment}"
def load_callable(reference: str) -> Callable[..., Any]:
    """Load a callable from a `module:attribute` reference string.

    The attribute part may be a dotted path (for example
    ``"package.module:ClassName.method"``); each segment is resolved with
    ``getattr`` in turn, starting from the imported module.

    NOTE(review): this imports whatever module the reference names, which
    runs that module's top-level code. Presumably references come from
    trusted configuration; confirm at the call sites.

    Args:
        reference: Reference in the form ``module.submodule:callable_name``.

    Returns:
        The resolved callable object.

    Raises:
        ValidationError: If the reference has no ``:`` separator or the
            resolved object is not callable.
        ImportError: If the module cannot be imported.
        AttributeError: If a segment of the attribute path does not exist.
    """
    if ":" not in reference:
        raise ValidationError(
            "Callable reference must use the format 'module.submodule:callable_name'."
        )
    module_name, attribute_name = reference.split(":", maxsplit=1)
    loaded: Any = importlib.import_module(module_name)
    # Walk dotted attribute paths one segment at a time. A single-segment
    # name behaves exactly like a plain ``getattr`` on the module.
    for segment in attribute_name.split("."):
        loaded = getattr(loaded, segment)
    if not callable(loaded):
        raise ValidationError(f"Reference '{reference}' does not resolve to a callable.")
    return cast(Callable[..., Any], loaded)