"""Problem-layer adapter utilities built on public problem APIs."""
from __future__ import annotations
import importlib
import random
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, cast
from ..schemas import ValidationError
[docs]
@dataclass(slots=True)
class ProblemPacket:
    """Normalized executable problem payload.

    Built by ``resolve_problem`` from an existing packet, a string id, a plain
    mapping, or an arbitrary problem object, and consumed by
    ``evaluate_problem``.
    """

    # Identifier of the problem.
    problem_id: str
    # Family / kind label; resolution uses "unknown" when it cannot tell.
    family: str
    # Human-readable problem brief (may fall back to the problem id).
    brief: str
    # Problem-specific data; packets built from problem objects keep the
    # original object under the "problem_object" key.
    payload: dict[str, Any] = field(default_factory=dict)
    # Descriptive metadata (e.g. problem class, title, summary, capabilities).
    metadata: dict[str, Any] = field(default_factory=dict)
    # Optional scoring callable invoked by ``evaluate_problem``. Despite the
    # annotation, it may receive a single value picked from the run output
    # (e.g. run_output["candidate"]) rather than the whole mapping.
    evaluator: Callable[[Mapping[str, Any]], Any] | None = None
[docs]
def resolve_problem(
    problem_spec_ref: Any,
    *,
    registry: Mapping[str, ProblemPacket] | None = None,
) -> ProblemPacket:
    """Resolve a problem reference into a normalized packet.

    Accepted references, checked in this order:

    * a ``ProblemPacket`` -- returned unchanged;
    * a ``str`` -- looked up in ``registry`` first, then in the optional
      ``design_research_problems`` package; unknown ids yield a placeholder
      packet of family ``"unknown"`` whose brief is the id itself;
    * a ``Mapping`` -- its ``problem_id``/``family``/``brief``/``payload``/
      ``metadata``/``evaluator`` entries are copied into a new packet;
    * any other object -- normalized via ``_packet_from_object``.

    Args:
        problem_spec_ref: The reference to resolve.
        registry: Optional ``problem_id -> packet`` lookup consulted before the
            upstream package for string references.

    Returns:
        The resolved ``ProblemPacket``.

    Raises:
        ValidationError: If a mapping or problem object supplies an evaluator
            that is present but not callable.
    """
    if isinstance(problem_spec_ref, ProblemPacket):
        return problem_spec_ref
    if isinstance(problem_spec_ref, str):
        return _packet_from_identifier(problem_spec_ref, registry=registry)
    if isinstance(problem_spec_ref, Mapping):
        return _packet_from_mapping(problem_spec_ref)
    return _packet_from_object(problem_spec_ref)


def _packet_from_identifier(
    problem_id: str,
    *,
    registry: Mapping[str, ProblemPacket] | None,
) -> ProblemPacket:
    """Resolve a string id via ``registry``, then upstream, then a placeholder."""
    if registry and problem_id in registry:
        return registry[problem_id]
    upstream = _resolve_from_design_research_problems(problem_id)
    if upstream is not None:
        return upstream
    # Unknown id: keep it usable as both identifier and brief.
    return ProblemPacket(
        problem_id=problem_id,
        family="unknown",
        brief=problem_id,
        payload={"problem_id": problem_id},
        metadata={},
    )


def _packet_from_mapping(spec: Mapping[str, Any]) -> ProblemPacket:
    """Build a packet from a plain mapping specification.

    Missing ``problem_id``/``family``/``brief`` entries default to
    ``"problem"``/``"unknown"``/``""``. Missing or ``None`` ``payload`` and
    ``metadata`` entries become empty dicts; otherwise they are shallow-copied.

    Raises:
        ValidationError: If ``evaluator`` is present, not ``None``, and not
            callable (mirrors the check applied to problem objects).
    """
    evaluator = spec.get("evaluator")
    if evaluator is not None and not callable(evaluator):
        raise ValidationError("Problem evaluator must be callable when present.")
    payload = spec.get("payload")
    metadata = spec.get("metadata")
    return ProblemPacket(
        problem_id=str(spec.get("problem_id", "problem")),
        family=str(spec.get("family", "unknown")),
        brief=str(spec.get("brief", "")),
        # ``dict(None)`` raises TypeError, so treat an explicit None as "empty".
        payload={} if payload is None else dict(cast(Mapping[str, Any], payload)),
        metadata={} if metadata is None else dict(cast(Mapping[str, Any], metadata)),
        evaluator=cast(Callable[[Mapping[str, Any]], Any] | None, evaluator),
    )
def evaluate_problem(
    packet: ProblemPacket,
    run_output: Mapping[str, Any],
) -> list[dict[str, Any]]:
    """Score ``run_output`` with the packet's evaluator and return canonical rows.

    Args:
        packet: Resolved problem packet; its ``evaluator`` may be ``None``.
        run_output: Output of one run. The evaluator is called with the value
            selected by ``_resolve_evaluator_input`` (a preferred key's value,
            or the whole mapping).

    Returns:
        Normalized evaluation rows, or an empty list when the packet carries
        no evaluator. Exceptions raised by the evaluator propagate unchanged.
    """
    evaluator = packet.evaluator
    if evaluator is None:
        return []
    return _normalize_evaluation_payload(
        evaluator(_resolve_evaluator_input(packet, run_output))
    )
def sample_problem_packets(
    problem_refs: Sequence[Any],
    *,
    registry: Mapping[str, ProblemPacket] | None = None,
    sample_size: int | None = None,
    seed: int = 0,
    balanced_by_family: bool = False,
) -> list[ProblemPacket]:
    """Resolve and sample problem packets with optional family balancing.

    Every reference is resolved with ``resolve_problem`` first. When
    ``sample_size`` is ``None`` or not smaller than the number of resolved
    packets, all packets are returned in input order. Otherwise:

    * unbalanced -- ``random.Random(seed).sample`` draws ``sample_size`` packets;
    * balanced -- packets are grouped by ``family``, each group is shuffled with
      the seeded generator (groups in first-appearance order), and packets are
      drawn round-robin over the families in sorted name order until
      ``sample_size`` packets are collected. When ``sample_size`` is not a
      multiple of the number of families, families earlier in sorted order
      receive the extra draws.

    Args:
        problem_refs: References accepted by ``resolve_problem``.
        registry: Optional id -> packet lookup forwarded to ``resolve_problem``.
        sample_size: Number of packets to keep, or ``None`` for all.
        seed: Seed for the deterministic ``random.Random`` instance.
        balanced_by_family: Whether to spread the sample across families.

    Returns:
        The sampled packets.

    Raises:
        ValueError: If ``sample_size`` is negative (previously only raised in
            the unbalanced mode; the balanced mode silently returned ``[]``).
    """
    if sample_size is not None and sample_size < 0:
        raise ValueError(f"sample_size must be non-negative, got {sample_size}.")
    resolved = [resolve_problem(problem_ref, registry=registry) for problem_ref in problem_refs]
    if sample_size is None or sample_size >= len(resolved):
        return resolved
    randomizer = random.Random(seed)
    if not balanced_by_family:
        return randomizer.sample(resolved, sample_size)

    buckets: dict[str, list[ProblemPacket]] = {}
    for packet in resolved:
        buckets.setdefault(packet.family, []).append(packet)
    # Shuffle in first-appearance order so a given seed is reproducible.
    for bucket in buckets.values():
        randomizer.shuffle(bucket)

    sampled: list[ProblemPacket] = []
    families = sorted(buckets)
    while families and len(sampled) < sample_size:
        # Only families that still hold packets survive into the next round,
        # so every bucket visited here is non-empty.
        remaining: list[str] = []
        for family in families:
            bucket = buckets[family]
            sampled.append(bucket.pop())
            if len(sampled) >= sample_size:
                break
            if bucket:
                remaining.append(family)
        families = remaining
    return sampled
def _resolve_from_design_research_problems(problem_id: str) -> ProblemPacket | None:
    """Attempt to resolve ``problem_id`` via the optional upstream package.

    ``design_research_problems`` is treated as optional, so every failure mode
    returns ``None`` and lets the caller fall back to other strategies:

    * the package cannot be imported, or
    * it exposes no callable ``get_problem``, or
    * the lookup raises, or
    * the lookup returns ``None``.

    Args:
        problem_id: Identifier forwarded to ``design_research_problems.get_problem``.

    Returns:
        A normalized ``ProblemPacket`` for the upstream problem, or ``None``.
    """
    try:
        module = importlib.import_module("design_research_problems")
        get_problem = getattr(module, "get_problem", None)
    except Exception:  # missing or broken optional dependency: treat as unavailable
        return None
    if not callable(get_problem):
        return None
    try:
        problem_obj = get_problem(problem_id)
    except Exception:  # unknown ids / upstream errors: let the caller fall back
        return None
    if problem_obj is None:
        # Normalizing None would produce a bogus packet (family "NoneType");
        # treat it as "not found" instead.
        return None
    return _packet_from_object(problem_obj, fallback_problem_id=problem_id)
def _packet_from_object(problem_obj: Any, fallback_problem_id: str | None = None) -> ProblemPacket:
    """Convert an arbitrary problem-like object into a ``ProblemPacket``.

    Field sources, first non-empty text wins:

    * ``problem_id``: ``problem_obj.metadata.problem_id``, then
      ``problem_obj.problem_id``, then ``fallback_problem_id``, then ``"problem"``.
    * ``family``: ``problem_obj.metadata.kind`` (enum values unwrapped), then
      ``problem_obj.family``, then the object's class name.
    * ``brief``: ``_resolve_problem_brief`` with the resolved id as fallback.
    * ``evaluator``: ``problem_obj.evaluate`` when present.

    The payload keeps the original object under ``"problem_object"``; the
    metadata records ``"problem_class"`` plus whatever
    ``_extract_problem_metadata`` finds.

    Raises:
        ValidationError: If ``problem_obj.evaluate`` exists but is not callable.
    """
    meta = getattr(problem_obj, "metadata", None)
    class_name = problem_obj.__class__.__name__

    resolved_id = _stringify_first(
        getattr(meta, "problem_id", None),
        getattr(problem_obj, "problem_id", None),
        fallback_problem_id,
        "problem",
    )
    resolved_family = _stringify_first(
        _value_or_enum(getattr(meta, "kind", None)),
        getattr(problem_obj, "family", None),
        class_name,
    )
    resolved_brief = _resolve_problem_brief(problem_obj, fallback=resolved_id)

    evaluate = getattr(problem_obj, "evaluate", None)
    if not (evaluate is None or callable(evaluate)):
        raise ValidationError("Problem evaluator must be callable when present.")

    return ProblemPacket(
        problem_id=resolved_id,
        family=resolved_family,
        brief=resolved_brief,
        payload={"problem_object": problem_obj},
        metadata={"problem_class": class_name, **_extract_problem_metadata(problem_obj)},
        evaluator=cast(Callable[[Mapping[str, Any]], Any] | None, evaluate),
    )
def _resolve_problem_brief(problem_obj: Any, *, fallback: str) -> str:
    """Resolve the richest available human-readable brief for a problem object.

    Sources are tried in priority order; the first that normalizes to
    non-empty text (via ``_normalize_text``) wins:

    1. the result of ``problem_obj.render_brief()``, when callable -- any
       exception it raises is swallowed so a broken renderer falls through;
    2. the ``statement_markdown``, ``brief`` and ``prompt`` attributes;
    3. ``problem_obj.metadata.summary``, then ``problem_obj.metadata.title``;
    4. ``fallback``.
    """
    render_brief = getattr(problem_obj, "render_brief", None)
    if callable(render_brief):
        try:
            rendered = render_brief()
        except Exception:  # best-effort: covers TypeError (e.g. required args) too
            rendered = None
        normalized = _normalize_text(rendered)
        if normalized is not None:
            return normalized
    for attribute_name in ("statement_markdown", "brief", "prompt"):
        normalized = _normalize_text(getattr(problem_obj, attribute_name, None))
        if normalized is not None:
            return normalized
    metadata_object = getattr(problem_obj, "metadata", None)
    for attribute_name in ("summary", "title"):
        normalized = _normalize_text(getattr(metadata_object, attribute_name, None))
        if normalized is not None:
            return normalized
    return fallback
def _extract_problem_metadata(problem_obj: Any) -> dict[str, Any]:
    """Collect interoperable metadata from ``problem_obj.metadata``.

    Reads ``title``, ``summary``, ``kind`` (stored as ``problem_kind``, enum
    values unwrapped), ``capabilities``, ``study_suitability``,
    ``feature_flags`` and ``implementation``. A key is only emitted when its
    normalized value is non-empty; a missing ``metadata`` yields ``{}``.
    """
    source = getattr(problem_obj, "metadata", None)
    # (attribute on the metadata object, output key, normalizer); the tuple
    # order fixes both the lookup order and the key order of the result.
    readers: tuple[tuple[str, str, Callable[[Any], Any]], ...] = (
        ("title", "title", _normalize_text),
        ("summary", "summary", _normalize_text),
        ("kind", "problem_kind", lambda raw: _normalize_text(_value_or_enum(raw))),
        ("capabilities", "capabilities", _string_sequence),
        ("study_suitability", "study_suitability", _string_sequence),
        ("feature_flags", "feature_flags", _string_sequence),
        ("implementation", "implementation", _normalize_text),
    )
    extracted: dict[str, Any] = {}
    for attribute, key, normalize in readers:
        value = normalize(getattr(source, attribute, None))
        # Text normalizers yield None and sequence normalizers yield () for
        # absent data; both are falsy, so one truthiness test covers them.
        if value:
            extracted[key] = value
    return extracted
def _resolve_evaluator_input(packet: ProblemPacket, run_output: Mapping[str, Any]) -> Any:
    """Pick what to hand a problem evaluator from one run's output.

    The first key present among ``candidate``, ``state``, ``answer``,
    ``solution``, ``final_answer`` and ``x`` wins; its value is returned even
    when falsy or ``None``. When none is present, the whole ``run_output``
    mapping is passed through. ``packet`` is currently unused.
    """
    present_values = (
        run_output[key]
        for key in ("candidate", "state", "answer", "solution", "final_answer", "x")
        if key in run_output
    )
    return next(present_values, run_output)
def _normalize_evaluation_payload(raw: Any) -> list[dict[str, Any]]:
    """Flatten an evaluator result into canonical evaluation rows.

    * Mappings that already look like a row (a ``metric_name``,
      ``metric_value`` or ``value`` key) become one normalized row; other
      mappings are expanded into one row per scalar metric.
    * Sequences other than ``str``/``bytes`` are flattened recursively.
    * Anything else goes through ``_object_to_mapping`` and is expanded like a
      metrics mapping; values that cannot be converted yield no rows.
    """
    if isinstance(raw, Mapping):
        return (
            [_normalize_evaluation_row(raw)]
            if _looks_like_evaluation_row(raw)
            else _metric_rows_from_mapping(raw)
        )
    is_flattenable = isinstance(raw, Sequence) and not isinstance(raw, (str, bytes))
    if is_flattenable:
        return [row for item in raw for row in _normalize_evaluation_payload(item)]
    as_mapping = _object_to_mapping(raw)
    return [] if as_mapping is None else _metric_rows_from_mapping(as_mapping)
def _looks_like_evaluation_row(row: Mapping[str, Any]) -> bool:
    """Report whether ``row`` already carries a metric name or value key."""
    for marker in ("metric_name", "metric_value", "value"):
        if marker in row:
            return True
    return False
def _metric_rows_from_mapping(metrics: Mapping[str, Any]) -> list[dict[str, Any]]:
    """Turn a ``{metric_name: value}`` mapping into canonical evaluation rows.

    One row is emitted per entry whose value passes ``_is_metric_scalar``; the
    ``higher_is_better`` key is always skipped. Rows use the fixed evaluator
    id ``"problem_evaluator"``, unit ``"unitless"``, aggregation ``"run"``,
    and a fresh empty ``notes_json`` dict.
    """
    return [
        {
            "evaluator_id": "problem_evaluator",
            "metric_name": str(name),
            "metric_value": value,
            "metric_unit": "unitless",
            "aggregation_level": "run",
            "notes_json": {},
        }
        for name, value in metrics.items()
        if str(name) != "higher_is_better" and _is_metric_scalar(value)
    ]
def _object_to_mapping(value: Any) -> Mapping[str, Any] | None:
    """Best-effort view of an evaluation result object as a flat mapping.

    Tried in order:

    1. ``None`` gives ``None``.
    2. Mappings are returned as-is.
    3. Dataclass instances become a ``{field: value}`` dict.
    4. A ``to_dict()`` method whose result is a mapping is used; exceptions
       from it are swallowed.
    5. The instance ``__dict__`` is returned live (not copied).

    Anything else yields ``None``.
    """
    if value is None or isinstance(value, Mapping):
        return value
    if is_dataclass(value) and not isinstance(value, type):
        return {spec.name: getattr(value, spec.name) for spec in fields(value)}
    exported: Any = None
    exporter = getattr(value, "to_dict", None)
    if callable(exporter):
        try:
            exported = exporter()
        except Exception:
            # A broken exporter must not hide the remaining fallbacks.
            exported = None
    if isinstance(exported, Mapping):
        return cast(Mapping[str, Any], exported)
    if not hasattr(value, "__dict__"):
        return None
    return cast(Mapping[str, Any], vars(value))
def _normalize_evaluation_row(row: Mapping[str, Any]) -> dict[str, Any]:
    """Coerce one row-like mapping into the canonical evaluation row shape.

    ``metric_value`` falls back to the ``value`` key only when the
    ``metric_value`` key is absent. String columns are stringified; missing
    ones default to ``"problem_evaluator"``, ``"score"``, ``"unitless"`` and
    ``"run"``. ``notes_json`` is passed through or defaults to ``{}``.
    """
    fallback_value = row.get("value")
    evaluator_id = str(row.get("evaluator_id", "problem_evaluator"))
    metric_name = str(row.get("metric_name", "score"))
    metric_value = row.get("metric_value", fallback_value)
    metric_unit = str(row.get("metric_unit", "unitless"))
    aggregation_level = str(row.get("aggregation_level", "run"))
    return {
        "evaluator_id": evaluator_id,
        "metric_name": metric_name,
        "metric_value": metric_value,
        "metric_unit": metric_unit,
        "aggregation_level": aggregation_level,
        "notes_json": row.get("notes_json", {}),
    }
def _value_or_enum(value: Any) -> Any:
    """Unwrap a non-``None`` ``.value`` attribute (e.g. an Enum member's value).

    Objects without a ``value`` attribute, or whose ``value`` is ``None``, are
    returned unchanged.
    """
    unwrapped = getattr(value, "value", None)
    return value if unwrapped is None else unwrapped
def _stringify_first(*values: Any) -> str:
"""Return the first non-empty stringified value."""
for value in values:
normalized = _normalize_text(value)
if normalized is not None:
return normalized
return ""
def _string_sequence(value: Any) -> tuple[str, ...]:
"""Normalize a loose sequence of values to a stable string tuple."""
if not isinstance(value, Sequence) or isinstance(value, (str, bytes)):
return ()
return tuple(str(item) for item in value if _normalize_text(item) is not None)
def _normalize_text(value: Any) -> str | None:
    """Stringify ``value`` and strip surrounding whitespace.

    ``None`` input and values whose stripped text is empty both map to
    ``None``, so callers only need to test for a single sentinel.
    """
    stripped = "" if value is None else str(value).strip()
    return stripped if stripped else None
def _is_metric_scalar(value: Any) -> bool:
    """Report whether ``value`` can be exported directly as a scalar metric.

    Only ``bool``, ``int`` and ``float`` instances (including subclasses)
    qualify; strings, ``None``, complex numbers and containers do not.
    """
    scalar_types = (bool, int, float)
    return isinstance(value, scalar_types)