Source code for design_research_experiments.conditions

"""Factor, level, constraint, and condition materialization primitives."""

from __future__ import annotations

import ast
import itertools
import random
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from enum import StrEnum
from typing import TYPE_CHECKING, Any, cast

from .schemas import (
    ConstraintSeverity,
    ValidationError,
    hash_identifier,
    load_callable,
    stable_json_dumps,
)

if TYPE_CHECKING:
    from .study import Block, Study


[docs] class FactorKind(StrEnum): """Classification for an experimental factor.""" MANIPULATED = "manipulated" MEASURED = "measured" BLOCKED = "blocked" NUISANCE = "nuisance"
[docs] @dataclass(slots=True) class Level: """One admissible level/value for a factor. Args: name: Stable level identifier. value: Encoded level value. label: Optional display label. metadata: Optional metadata payload. """ name: str value: Any label: str | None = None metadata: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: """Validate and normalize level content.""" if not self.name.strip(): raise ValidationError("Level.name must be non-empty.") if self.label is None: self.label = self.name
[docs] @dataclass(slots=True) class Factor: """Definition of one experimental factor. Args: name: Stable factor identifier. description: Human-readable description. kind: Factor type. levels: Allowed level set. dtype: Optional value type hint. default: Optional default level value. metadata: Optional metadata payload. """ name: str description: str kind: FactorKind = FactorKind.MANIPULATED levels: tuple[Level, ...] = () dtype: str | None = None default: Any = None metadata: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: """Validate factor structure.""" if not self.name.strip(): raise ValidationError("Factor.name must be non-empty.") if not self.levels and self.default is None: raise ValidationError( f"Factor '{self.name}' must declare at least one level or a default value." ) level_names = [level.name for level in self.levels] if len(level_names) != len(set(level_names)): raise ValidationError(f"Factor '{self.name}' has duplicate level names.")
[docs] def iter_values(self) -> tuple[Any, ...]: """Return all admissible values for cartesian materialization.""" if self.levels: return tuple(level.value for level in self.levels) return (self.default,)
[docs] @dataclass(slots=True) class Constraint: """Admissibility rule over factor and block assignments. Args: constraint_id: Stable identifier for this constraint. description: Human-readable description. expression: Optional safe expression string. callable_ref: Optional `module:callable` reference. severity: Whether violation should fail or warn. """ constraint_id: str description: str expression: str | None = None callable_ref: str | None = None severity: ConstraintSeverity = ConstraintSeverity.ERROR def __post_init__(self) -> None: """Validate that a rule implementation is configured.""" if not self.constraint_id.strip(): raise ValidationError("Constraint.constraint_id must be non-empty.") if self.expression is None and self.callable_ref is None: raise ValidationError("Constraint must define either `expression` or `callable_ref`.")
[docs] def evaluate(self, factors: Mapping[str, Any], blocks: Mapping[str, Any]) -> bool: """Return whether this constraint passes for one assignment.""" checks: list[bool] = [] if self.expression is not None: context: dict[str, Any] = {"factors": dict(factors), "blocks": dict(blocks)} context.update(factors) context.update(blocks) checks.append(evaluate_constraint_expression(self.expression, context)) if self.callable_ref is not None: predicate = load_callable(self.callable_ref) checks.append(bool(predicate(dict(factors), dict(blocks)))) return all(checks)
[docs] @dataclass(slots=True) class Condition: """One realized treatment combination. Args: condition_id: Stable condition ID. factor_assignments: Materialized factor assignments. block_assignments: Materialized block assignments. metadata: Optional metadata payload. admissible: Constraint admissibility flag. constraint_messages: Constraint warning/error messages. """ condition_id: str factor_assignments: dict[str, Any] block_assignments: dict[str, Any] metadata: dict[str, Any] = field(default_factory=dict) admissible: bool = True constraint_messages: list[str] = field(default_factory=list)
def evaluate_constraint_expression(expression: str, context: Mapping[str, Any]) -> bool: """Evaluate one safe boolean expression against an assignment context.""" parsed = ast.parse(expression, mode="eval") evaluated = _eval_ast_node(parsed.body, context) if not isinstance(evaluated, bool): raise ValidationError("Constraint expression must evaluate to a boolean value.") return evaluated def _eval_ast_node(node: ast.AST, context: Mapping[str, Any]) -> Any: """Evaluate a small safe AST subset for constraint expressions.""" if isinstance(node, ast.Constant): return node.value if isinstance(node, ast.Name): if node.id not in context: raise ValidationError(f"Unknown variable '{node.id}' in constraint expression.") return context[node.id] if isinstance(node, ast.List): return [_eval_ast_node(element, context) for element in node.elts] if isinstance(node, ast.Tuple): return tuple(_eval_ast_node(element, context) for element in node.elts) if isinstance(node, ast.Dict): keys: list[Any] = [] for element in node.keys: if element is None: raise ValidationError("Dictionary unpacking is not allowed in constraints.") keys.append(_eval_ast_node(element, context)) values = [_eval_ast_node(element, context) for element in node.values] return dict(zip(keys, values, strict=True)) if isinstance(node, ast.Attribute): base = _eval_ast_node(node.value, context) if isinstance(base, Mapping): return base.get(node.attr) return getattr(base, node.attr) if isinstance(node, ast.Subscript): target = _eval_ast_node(node.value, context) if node.slice is None: raise ValidationError("Subscript expressions require an index.") index = _eval_ast_node(node.slice, context) return target[index] if isinstance(node, ast.BoolOp): values = [_eval_ast_node(value, context) for value in node.values] if isinstance(node.op, ast.And): return all(bool(value) for value in values) if isinstance(node.op, ast.Or): return any(bool(value) for value in values) if isinstance(node, ast.UnaryOp): operand = _eval_ast_node(node.operand, context) if isinstance(node.op, ast.Not): return not bool(operand) if isinstance(node.op, ast.UAdd): return +operand if isinstance(node.op, ast.USub): return -operand if isinstance(node, ast.BinOp): left = _eval_ast_node(node.left, context) right = _eval_ast_node(node.right, context) if isinstance(node.op, ast.Add): return left + right if isinstance(node.op, ast.Sub): return left - right if isinstance(node.op, ast.Mult): return left * right if isinstance(node.op, ast.Div): return left / right if isinstance(node.op, ast.Mod): return left % right if isinstance(node, ast.Compare): left = _eval_ast_node(node.left, context) for operator_node, comparator_node in zip(node.ops, node.comparators, strict=True): right = _eval_ast_node(comparator_node, context) if not _eval_comparison(operator_node, left, right): return False left = right return True if isinstance(node, ast.Call): if not isinstance(node.func, ast.Name): raise ValidationError("Only direct function names are allowed in constraints.") function_name = node.func.id allowed_functions: dict[str, Any] = { "len": len, "int": int, "float": float, "str": str, "bool": bool, } if function_name not in allowed_functions: raise ValidationError(f"Function '{function_name}' is not allowed in constraints.") positional_args = [_eval_ast_node(arg, context) for arg in node.args] if node.keywords: raise ValidationError("Keyword arguments are not allowed in constraint function calls.") return allowed_functions[function_name](*positional_args) raise ValidationError(f"Unsupported expression node: {type(node).__name__}") def _eval_comparison(operator_node: ast.AST, left: Any, right: Any) -> bool: """Evaluate one comparison operator.""" if isinstance(operator_node, ast.Eq): return bool(left == right) if isinstance(operator_node, ast.NotEq): return bool(left != right) if isinstance(operator_node, ast.Lt): return bool(left < right) if isinstance(operator_node, ast.LtE): return bool(left <= right) if isinstance(operator_node, ast.Gt): return bool(left > right) if isinstance(operator_node, ast.GtE): return bool(left >= right) if isinstance(operator_node, ast.In): return bool(left in right) if isinstance(operator_node, ast.NotIn): return bool(left not in right) raise ValidationError(f"Unsupported comparison operator: {type(operator_node).__name__}")
[docs] def materialize_conditions( factors: Sequence[Factor] | Study, blocks: Sequence[Block] | None = None, constraints: Sequence[Constraint] | None = None, *, seed: int | None = None, randomize: bool = False, counterbalance: bool = False, ) -> list[Condition]: """Materialize admissible conditions from factors, blocks, and constraints.""" resolved_factors, resolved_blocks, resolved_constraints = _normalize_inputs( factors=factors, blocks=blocks, constraints=constraints, ) factor_names = [factor.name for factor in resolved_factors] factor_level_values = [factor.iter_values() for factor in resolved_factors] block_names = [block.name for block in resolved_blocks] block_values = [tuple(block.levels) for block in resolved_blocks] factor_assignments: list[dict[str, Any]] = [] for combination in itertools.product(*factor_level_values): factor_assignments.append(dict(zip(factor_names, combination, strict=True))) block_assignments: list[dict[str, Any]] = [{}] if block_values: block_assignments = [] for combination in itertools.product(*block_values): block_assignments.append(dict(zip(block_names, combination, strict=True))) conditions: list[Condition] = [] for factor_assignment in factor_assignments: for block_assignment in block_assignments: condition = _build_condition( factor_assignment=factor_assignment, block_assignment=block_assignment, constraints=resolved_constraints, ) conditions.append(condition) if counterbalance: conditions = counterbalance_conditions(conditions) if randomize: randomizer = random.Random(seed) randomizer.shuffle(conditions) return conditions
def balanced_randomization_schedule( condition_ids: Sequence[str], replicates: int, *, seed: int | None = None, ) -> list[tuple[str, int, int]]: """Create a balanced randomized `(condition_id, replicate, order)` schedule.""" if replicates < 1: raise ValidationError("Replicates must be >= 1.") if not condition_ids: return [] randomizer = random.Random(seed) base_order = list(condition_ids) randomizer.shuffle(base_order) schedule: list[tuple[str, int, int]] = [] for replicate in range(1, replicates + 1): rotate_by = (replicate - 1) % len(base_order) replicate_order = base_order[rotate_by:] + base_order[:rotate_by] for order_index, condition_id in enumerate(replicate_order): schedule.append((condition_id, replicate, order_index)) return schedule def counterbalance_conditions(conditions: Sequence[Condition]) -> list[Condition]: """Apply a simple front/back interleaving for order counterbalancing.""" ordered = list(conditions) if len(ordered) < 3: return ordered first_half = ordered[::2] second_half = ordered[1::2] second_half.reverse() return first_half + second_half def _normalize_inputs( factors: Sequence[Factor] | Study, blocks: Sequence[Block] | None, constraints: Sequence[Constraint] | None, ) -> tuple[list[Factor], list[Block], list[Constraint]]: """Resolve whether inputs were passed directly or via a `Study` object.""" from .study import Study if isinstance(factors, Study): study = factors return list(study.factors), list(study.blocks), list(study.constraints) return list(factors), list(blocks or ()), list(constraints or ()) def _build_condition( factor_assignment: Mapping[str, Any], block_assignment: Mapping[str, Any], constraints: Sequence[Constraint], ) -> Condition: """Build one condition and evaluate all constraints.""" admissible = True messages: list[str] = [] for constraint in constraints: is_valid = constraint.evaluate(factors=factor_assignment, blocks=block_assignment) if is_valid: continue message = f"{constraint.constraint_id}: {constraint.description}" messages.append(message) if constraint.severity == ConstraintSeverity.ERROR: admissible = False condition_id = hash_identifier( "cond", { "factors": factor_assignment, "blocks": block_assignment, }, ) return Condition( condition_id=condition_id, factor_assignments=dict(factor_assignment), block_assignments=dict(block_assignment), metadata={ "fingerprint": stable_json_dumps( { "factors": factor_assignment, "blocks": block_assignment, } ) }, admissible=admissible, constraint_messages=messages, ) def _coerce_factor(value: Any) -> Factor: """Coerce a loose object into a `Factor` instance.""" if isinstance(value, Factor): return value mapping = cast(Mapping[str, Any], value) levels = tuple( level if isinstance(level, Level) else Level(**cast(dict[str, Any], level)) for level in cast(Sequence[Any], mapping.get("levels", ())) ) return Factor( name=str(mapping["name"]), description=str(mapping.get("description", "")), kind=FactorKind(str(mapping.get("kind", FactorKind.MANIPULATED.value))), levels=levels, dtype=cast(str | None, mapping.get("dtype")), default=mapping.get("default"), metadata=dict(cast(Mapping[str, Any], mapping.get("metadata", {}))), )