# Source code for design_research_experiments.adapters.problems

"""Problem-layer adapter utilities built on public problem APIs."""

from __future__ import annotations

import importlib
import random
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any, cast

from ..schemas import ValidationError


[docs] @dataclass(slots=True) class ProblemPacket: """Normalized executable problem payload.""" problem_id: str family: str brief: str payload: dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict) evaluator: Callable[[Mapping[str, Any]], Any] | None = None
def resolve_problem(
    problem_spec_ref: Any,
    *,
    registry: Mapping[str, ProblemPacket] | None = None,
) -> ProblemPacket:
    """Resolve a problem reference into a normalized packet.

    Accepted reference forms, checked in this order:

    * a ``ProblemPacket`` -- returned unchanged;
    * a ``str`` -- looked up in ``registry`` first, then in the optional
      ``design_research_problems`` package, and finally wrapped in a
      placeholder packet with ``family="unknown"``;
    * a ``Mapping`` -- its ``problem_id``/``family``/``brief``/``payload``/
      ``metadata``/``evaluator`` keys are copied into a new packet;
    * any other object -- normalized by attribute lookup.

    Args:
        problem_spec_ref: Problem reference in one of the forms above.
        registry: Optional id -> packet mapping consulted before upstream
            resolution for string references.

    Returns:
        The resolved ``ProblemPacket``.

    Raises:
        ValidationError: If the reference supplies a non-callable evaluator.
    """
    if isinstance(problem_spec_ref, ProblemPacket):
        return problem_spec_ref
    if isinstance(problem_spec_ref, str):
        if registry and problem_spec_ref in registry:
            return registry[problem_spec_ref]
        packet = _resolve_from_design_research_problems(problem_spec_ref)
        if packet is not None:
            return packet
        # Unknown id: keep going with a minimal placeholder packet.
        return ProblemPacket(
            problem_id=problem_spec_ref,
            family="unknown",
            brief=problem_spec_ref,
            payload={"problem_id": problem_spec_ref},
            metadata={},
        )
    if isinstance(problem_spec_ref, Mapping):
        return ProblemPacket(
            problem_id=str(problem_spec_ref.get("problem_id", "problem")),
            family=str(problem_spec_ref.get("family", "unknown")),
            brief=str(problem_spec_ref.get("brief", "")),
            payload=_copy_mapping(problem_spec_ref.get("payload")),
            metadata=_copy_mapping(problem_spec_ref.get("metadata")),
            evaluator=_checked_evaluator(problem_spec_ref.get("evaluator")),
        )
    return _packet_from_object(problem_spec_ref)


def evaluate_problem(
    packet: ProblemPacket,
    run_output: Mapping[str, Any],
) -> list[dict[str, Any]]:
    """Execute family-specific evaluation when available and normalize rows.

    Args:
        packet: Packet whose ``evaluator`` (if any) is invoked.
        run_output: Mapping passed verbatim to the evaluator.

    Returns:
        One normalized row when the evaluator returns a mapping, one row per
        element when it returns a non-string sequence of mappings, and an
        empty list when there is no evaluator or the output has any other
        type (scalars, strings, bytes, ``None``, ...).

    Raises:
        ValidationError: If a returned sequence contains a non-mapping row.
    """
    if packet.evaluator is None:
        return []
    raw = packet.evaluator(run_output)
    if isinstance(raw, Mapping):
        return [_normalize_evaluation_row(raw)]
    # ``str``/``bytes`` are Sequences too; iterating them would yield single
    # characters rather than rows, so they are treated as unsupported output.
    if isinstance(raw, Sequence) and not isinstance(raw, (str, bytes, bytearray)):
        rows: list[dict[str, Any]] = []
        for index, row in enumerate(raw):
            if not isinstance(row, Mapping):
                raise ValidationError(
                    f"Evaluator row {index} must be a mapping, got {type(row).__name__}."
                )
            rows.append(_normalize_evaluation_row(row))
        return rows
    return []


def sample_problem_packets(
    problem_refs: Sequence[Any],
    *,
    registry: Mapping[str, ProblemPacket] | None = None,
    sample_size: int | None = None,
    seed: int = 0,
    balanced_by_family: bool = False,
) -> list[ProblemPacket]:
    """Resolve and sample problem packets with optional family balancing.

    Every reference is resolved with ``resolve_problem``. When
    ``sample_size`` is ``None`` or at least the number of references, all
    packets are returned in input order. Otherwise ``random.Random(seed)``
    drives a reproducible draw: a plain ``sample`` or, with
    ``balanced_by_family``, a round-robin over families (sorted by name)
    that takes one shuffled packet per family per round.

    NOTE: because families are visited in sorted order, a sample smaller
    than the number of families favors alphabetically earlier families.

    Raises:
        ValueError: If ``sample_size`` is negative.
    """
    # Reject up front so both sampling strategies fail the same way
    # (``random.sample`` already raised ValueError on the unbalanced path).
    if sample_size is not None and sample_size < 0:
        raise ValueError("sample_size must be non-negative.")

    resolved = [resolve_problem(problem_ref, registry=registry) for problem_ref in problem_refs]
    if sample_size is None or sample_size >= len(resolved):
        return resolved

    randomizer = random.Random(seed)
    if not balanced_by_family:
        return randomizer.sample(resolved, sample_size)

    # Group by family (insertion order = first occurrence) and shuffle each group.
    buckets: dict[str, list[ProblemPacket]] = {}
    for packet in resolved:
        buckets.setdefault(packet.family, []).append(packet)
    for bucket in buckets.values():
        randomizer.shuffle(bucket)

    sampled: list[ProblemPacket] = []
    families = sorted(buckets)
    while families and len(sampled) < sample_size:
        # Only families that still have packets survive into the next round.
        remaining: list[str] = []
        for family in families:
            bucket = buckets[family]
            sampled.append(bucket.pop())
            if len(sampled) >= sample_size:
                break
            if bucket:
                remaining.append(family)
        families = remaining
    return sampled


def _resolve_from_design_research_problems(problem_id: str) -> ProblemPacket | None:
    """Attempt resolving ``problem_id`` via the optional upstream package.

    Returns ``None`` so the caller can fall back when the
    ``design_research_problems`` package is not importable, exposes no
    callable ``get_problem``, raises during lookup, or returns ``None``.
    """
    try:
        module = importlib.import_module("design_research_problems")
    except ImportError:
        return None
    get_problem = getattr(module, "get_problem", None)
    if not callable(get_problem):
        return None
    try:
        problem_obj = get_problem(problem_id)
    except Exception:
        # Deliberate best-effort: any upstream lookup failure falls back to
        # the caller's placeholder packet instead of aborting resolution.
        return None
    if problem_obj is None:
        # Without this guard a packet with family "NoneType" would be built.
        return None
    return _packet_from_object(problem_obj, fallback_problem_id=problem_id)


def _packet_from_object(problem_obj: Any, fallback_problem_id: str | None = None) -> ProblemPacket:
    """Normalize an arbitrary problem-like object into a `ProblemPacket`.

    Reads the optional attributes ``problem_id``, ``family``, ``brief``
    (falling back to ``prompt``, then to the id) and ``evaluate``. The
    original object is kept under ``payload["problem_object"]``.

    Raises:
        ValidationError: If ``evaluate`` is present but not callable.
    """
    default_id = fallback_problem_id if fallback_problem_id is not None else "problem"
    problem_id = str(getattr(problem_obj, "problem_id", default_id))
    family = str(getattr(problem_obj, "family", problem_obj.__class__.__name__))
    brief_candidate = getattr(problem_obj, "brief", None)
    if brief_candidate is None:
        brief_candidate = getattr(problem_obj, "prompt", None)
    brief = str(brief_candidate) if brief_candidate is not None else problem_id
    evaluator = _checked_evaluator(getattr(problem_obj, "evaluate", None))
    return ProblemPacket(
        problem_id=problem_id,
        family=family,
        brief=brief,
        payload={"problem_object": problem_obj},
        metadata={"problem_class": problem_obj.__class__.__name__},
        evaluator=evaluator,
    )


def _checked_evaluator(evaluator: Any) -> Callable[[Mapping[str, Any]], Any] | None:
    """Return ``evaluator`` unchanged if it is ``None`` or callable; else raise."""
    if evaluator is not None and not callable(evaluator):
        raise ValidationError("Problem evaluator must be callable when present.")
    return cast(Callable[[Mapping[str, Any]], Any] | None, evaluator)


def _copy_mapping(value: Any) -> dict[str, Any]:
    """Return a shallow ``dict`` copy of ``value``, treating ``None`` as empty."""
    if value is None:
        return {}
    return dict(cast(Mapping[str, Any], value))


def _normalize_evaluation_row(row: Mapping[str, Any]) -> dict[str, Any]:
    """Normalize one evaluator row to canonical shape.

    Missing keys get defaults; ``metric_value`` falls back to a ``value``
    key. Identifier-like fields are coerced to ``str``.
    """
    return {
        "evaluator_id": str(row.get("evaluator_id", "problem_evaluator")),
        "metric_name": str(row.get("metric_name", "score")),
        "metric_value": row.get("metric_value", row.get("value")),
        "metric_unit": str(row.get("metric_unit", "unitless")),
        "aggregation_level": str(row.get("aggregation_level", "run")),
        "notes_json": row.get("notes_json", {}),
    }