"""Problem registry facade."""
from __future__ import annotations
from collections.abc import Callable
from importlib import import_module
from typing import TYPE_CHECKING, Literal, TypeVar, cast, overload
from design_research_problems._catalog._loader import load_problem_manifests
from design_research_problems._catalog._manifest import ProblemManifest
from design_research_problems._exceptions import ProblemEvaluationError
from design_research_problems.problems import MCPProblem, Problem, ProblemKind, TextProblem
from design_research_problems.problems._decision import load_decision_problem
from design_research_problems.problems._metadata import ProblemMetadata
if TYPE_CHECKING:
from design_research_problems.problems import (
DecisionProblem,
MCPProblem,
OptimizationProblem,
)
from design_research_problems.problems.grammar import (
BatteryPack18650OpenEndedProblem,
BatteryPack18650SeriesParallelProblem,
IoTHomeCoolingGrammarProblem,
PlanarTrussSpanProblem,
SpaceTrussSpanProblem,
TrussAPGrammarProblem,
)
from design_research_problems.problems.optimization import (
BatteryGridSizingProblem,
BatteryOpenEndedCapacityMaxProblem,
GMPBOptimizationProblem,
PlanarTrussEngineeringOptimizationProblem,
SpaceTrussEngineeringOptimizationProblem,
)
# Literal problem-ID groups consumed by the ``@overload`` signatures of
# ``ProblemRegistry.get`` and ``get_problem`` so a type checker can infer a
# concrete return type from a literal ID. NOTE(review): nothing in this module
# checks these lists against the packaged manifests -- keep them in sync when
# catalog entries are added or renamed.

# IDs whose overloads return ``PlanarTrussSpanProblem``.
type _PlanarTrussProblemId = Literal[
    "planar_truss_span",
    "planar_roof_truss_three_point_symmetric",
    "planar_roof_truss_three_point_symmetric_depth_sixth",
    "planar_roof_truss_three_point_symmetric_depth_sixth_discrete_sizing",
    "planar_roof_truss_three_point_symmetric_depth_eighth",
    "planar_roof_truss_seven_point_symmetric",
    "planar_roof_truss_seven_point_asymmetric",
]
# IDs whose overloads return ``SpaceTrussSpanProblem``.
type _SpaceTrussProblemId = Literal["space_truss_span"]
# IDs whose overloads return ``IoTHomeCoolingGrammarProblem``.
type _IoTGrammarProblemId = Literal["iot_home_cooling_system_design"]
# IDs whose overloads return ``TrussAPGrammarProblem``.
type _TrussAPProblemId = Literal["truss_analysis_program_design"]
# MSEval material-selection decision IDs; only used through ``_DecisionProblemId``.
type _MSEvalProblemId = Literal[
    "decision_mseval_kitchen_utensil_grip_corrosion_resistant",
    "decision_mseval_kitchen_utensil_grip_high_strength",
    "decision_mseval_kitchen_utensil_grip_lightweight",
    "decision_mseval_kitchen_utensil_grip_resistant_to_heat",
    "decision_mseval_safety_helmet_corrosion_resistant",
    "decision_mseval_safety_helmet_high_strength",
    "decision_mseval_safety_helmet_lightweight",
    "decision_mseval_safety_helmet_resistant_to_heat",
    "decision_mseval_spacecraft_component_corrosion_resistant",
    "decision_mseval_spacecraft_component_high_strength",
    "decision_mseval_spacecraft_component_lightweight",
    "decision_mseval_spacecraft_component_resistant_to_heat",
    "decision_mseval_underwater_component_corrosion_resistant",
    "decision_mseval_underwater_component_high_strength",
    "decision_mseval_underwater_component_lightweight",
    "decision_mseval_underwater_component_resistant_to_heat",
]
# IDs whose overloads return ``DecisionProblem``.
type _DecisionProblemId = Literal["decision_laptop_design_profit_maximization"] | _MSEvalProblemId
# IDs whose overloads return the generic ``OptimizationProblem``. Several of these
# also have a more specific overload listed earlier, which a type checker tries first.
type _OptimizationProblemId = Literal[
    "battery_pack_18650_series_parallel_cost_min",
    "battery_pack_18650_open_ended_capacity_max",
    "gmpb_default_dynamic_min",
    "pill_capsule_min_area",
    "moneymaker_hip_pump_cost_min",
    "planar_truss_span_mass_min",
    "planar_truss_span_deflection_min",
    "planar_truss_span_fos_max",
    "space_truss_span_mass_min",
    "treadle_pump_ide_material_min",
]
# Type variable bound to ``Problem`` for helpers that narrow a loaded problem to a subclass.
_ProblemSubT = TypeVar("_ProblemSubT", bound=Problem)
def _resolve_object(import_path: str) -> object:
"""Import one ``module:attribute`` object reference.
Args:
import_path: Import target in ``module:attribute`` form.
Returns:
Imported Python object.
Raises:
design_research_problems.ProblemEvaluationError: If the import path is malformed.
"""
module_path, _, attr_name = import_path.partition(":")
if not module_path or not attr_name:
raise ProblemEvaluationError(f"Invalid implementation path: {import_path!r}")
return getattr(import_module(module_path), attr_name)
class ProblemRegistry:
    """Lazy-loading registry for the packaged problem catalog.

    Manifests are parsed by ``load_problem_manifests`` on first access and
    cached on the instance for all later queries.
    """

    def __init__(self) -> None:
        """Initialize an empty lazy registry cache."""
        # ``None`` means "not loaded yet"; ``_catalog`` fills it on first use.
        self._manifests: dict[str, ProblemManifest] | None = None

    def _catalog(self) -> dict[str, ProblemManifest]:
        """Return the cached manifest catalog, loading it on the first call.

        Returns:
            Mapping of problem IDs to parsed manifests.
        """
        if self._manifests is None:
            self._manifests = load_problem_manifests()
        return self._manifests

    def _manifest(self, problem_id: str) -> ProblemManifest:
        """Return the manifest for one problem ID.

        Args:
            problem_id: Stable catalog identifier.

        Returns:
            Parsed manifest registered under ``problem_id``.

        Raises:
            KeyError: If the ID is unknown.
        """
        manifest = self._catalog().get(problem_id)
        if manifest is None:
            raise KeyError(f"Unknown problem id: {problem_id}")
        return manifest

    @staticmethod
    def _normalize_flags(flags: tuple[str, ...]) -> set[str]:
        """Normalize query flags: trim, lowercase, and replace spaces with dashes.

        Args:
            flags: Raw flag strings supplied by a caller.

        Returns:
            Set of normalized flag strings.
        """
        return {flag.strip().lower().replace(" ", "-") for flag in flags}

    def list(self) -> tuple[ProblemMetadata, ...]:
        """Return all problem metadata in ID-sorted order.

        Returns:
            Metadata entries sorted by problem ID.
        """
        catalog = self._catalog()
        return tuple(catalog[problem_id].metadata for problem_id in sorted(catalog))

    def by_kind(self, kind: ProblemKind) -> tuple[ProblemMetadata, ...]:
        """Return metadata entries with the requested kind.

        Args:
            kind: Problem family to filter for.

        Returns:
            Matching metadata entries, in ID-sorted order.
        """
        return tuple(metadata for metadata in self.list() if metadata.kind is kind)

    def feature_flags(self, problem_id: str) -> tuple[str, ...]:
        """Return the feature flags for one problem ID.

        Args:
            problem_id: Stable catalog identifier.

        Returns:
            Feature flags in deterministic sorted order.

        Raises:
            KeyError: If the ID is unknown.
        """
        return self._manifest(problem_id).metadata.feature_flags

    def capabilities(self, problem_id: str) -> tuple[str, ...]:
        """Return the normalized capability flags for one problem ID.

        Args:
            problem_id: Stable catalog identifier.

        Returns:
            Capability flags in deterministic sorted order.

        Raises:
            KeyError: If the ID is unknown.
        """
        return self._manifest(problem_id).metadata.capabilities

    def study_suitability(self, problem_id: str) -> tuple[str, ...]:
        """Return the normalized study-suitability flags for one problem ID.

        Args:
            problem_id: Stable catalog identifier.

        Returns:
            Study-suitability flags in deterministic sorted order.

        Raises:
            KeyError: If the ID is unknown.
        """
        return self._manifest(problem_id).metadata.study_suitability

    def kind_feature_flags(self) -> dict[ProblemKind, tuple[str, ...]]:
        """Return aggregated feature flags for each problem family.

        Returns:
            Mapping of every problem kind to the sorted union of feature flags
            across that family (empty tuple for kinds with no problems).
        """
        grouped: dict[ProblemKind, set[str]] = {kind: set() for kind in ProblemKind}
        for metadata in self.list():
            grouped[metadata.kind].update(metadata.feature_flags)
        return {kind: tuple(sorted(grouped[kind])) for kind in ProblemKind}

    def search(
        self,
        tags: tuple[str, ...] = (),
        text: str = "",
        feature_flags: tuple[str, ...] = (),
        kind: ProblemKind | None = None,
        capabilities: tuple[str, ...] = (),
        study_suitability: tuple[str, ...] = (),
    ) -> tuple[ProblemMetadata, ...]:
        """Search metadata by tags, flags, kind, and free text.

        All supplied filters must match (logical AND); empty filters are ignored.

        Args:
            tags: Tags that must all be present on a matching entry (case- and
                surrounding-whitespace-insensitive).
            text: Case-insensitive free-text search term, matched as a substring
                of the problem ID, title, summary, and tags.
            feature_flags: Feature flags that must all be present on a matching entry.
            kind: Optional problem-family filter.
            capabilities: Capability flags that must all be present.
            study_suitability: Study-suitability flags that must all be present.

        Returns:
            Matching metadata entries, in ID-sorted order.
        """
        # Strip tags too, so padded tag queries behave like padded flag queries.
        tag_set = {tag.strip().lower() for tag in tags}
        feature_flag_set = self._normalize_flags(feature_flags)
        capability_set = self._normalize_flags(capabilities)
        suitability_set = self._normalize_flags(study_suitability)
        text_query = text.strip().lower()
        matches: list[ProblemMetadata] = []
        for metadata in self.list():
            if kind is not None and metadata.kind is not kind:
                continue
            if tag_set and not tag_set.issubset({tag.lower() for tag in metadata.taxonomy.tags}):
                continue
            if feature_flag_set and not feature_flag_set.issubset(metadata.feature_flags):
                continue
            if capability_set and not capability_set.issubset(metadata.capabilities):
                continue
            if suitability_set and not suitability_set.issubset(metadata.study_suitability):
                continue
            if text_query:
                # Only build the haystack when a text query was actually given.
                haystack = " ".join(
                    (metadata.problem_id, metadata.title, metadata.summary, *metadata.taxonomy.tags)
                ).lower()
                if text_query not in haystack:
                    continue
            matches.append(metadata)
        return tuple(matches)

    # Typed overloads: type checkers use the first matching signature, so the
    # specific literal IDs must stay ahead of the broader aliases and ``str``.
    @overload
    def get(
        self, problem_id: Literal["battery_pack_18650_series_parallel"]
    ) -> BatteryPack18650SeriesParallelProblem: ...
    @overload
    def get(self, problem_id: Literal["battery_pack_18650_open_ended"]) -> BatteryPack18650OpenEndedProblem: ...
    @overload
    def get(self, problem_id: _PlanarTrussProblemId) -> PlanarTrussSpanProblem: ...
    @overload
    def get(self, problem_id: Literal["battery_pack_18650_series_parallel_cost_min"]) -> BatteryGridSizingProblem: ...
    @overload
    def get(
        self, problem_id: Literal["battery_pack_18650_open_ended_capacity_max"]
    ) -> BatteryOpenEndedCapacityMaxProblem: ...
    @overload
    def get(self, problem_id: Literal["gmpb_default_dynamic_min"]) -> GMPBOptimizationProblem: ...
    @overload
    def get(self, problem_id: _SpaceTrussProblemId) -> SpaceTrussSpanProblem: ...
    @overload
    def get(self, problem_id: _IoTGrammarProblemId) -> IoTHomeCoolingGrammarProblem: ...
    @overload
    def get(self, problem_id: _TrussAPProblemId) -> TrussAPGrammarProblem: ...
    @overload
    def get(
        self,
        problem_id: Literal[
            "planar_truss_span_mass_min",
            "planar_truss_span_deflection_min",
            "planar_truss_span_fos_max",
        ],
    ) -> PlanarTrussEngineeringOptimizationProblem: ...
    @overload
    def get(self, problem_id: Literal["space_truss_span_mass_min"]) -> SpaceTrussEngineeringOptimizationProblem: ...
    @overload
    def get(self, problem_id: _OptimizationProblemId) -> OptimizationProblem: ...
    @overload
    def get(
        self,
        problem_id: _DecisionProblemId,
    ) -> DecisionProblem: ...
    @overload
    def get(self, problem_id: str) -> Problem: ...

    def get(self, problem_id: str) -> Problem:
        """Instantiate one problem by ID.

        Args:
            problem_id: Stable catalog identifier.

        Returns:
            Loaded problem instance.

        Raises:
            KeyError: If the ID is unknown.
            design_research_problems.ProblemEvaluationError: If the packaged implementation metadata is invalid.
        """
        manifest = self._manifest(problem_id)
        implementation = manifest.metadata.implementation
        if implementation is not None:
            target = _resolve_object(implementation)
            # Prefer a callable ``from_manifest`` attribute on the target; otherwise
            # treat the target itself as a ``manifest -> Problem`` factory.
            factory = getattr(target, "from_manifest", None)
            if callable(factory):
                manifest_factory = cast(Callable[[ProblemManifest], Problem], factory)
                return manifest_factory(manifest)
            if callable(target):
                direct_factory = cast(Callable[[ProblemManifest], Problem], target)
                return direct_factory(manifest)
            raise ProblemEvaluationError(f"Problem implementation for {problem_id!r} is not callable.")
        # No explicit implementation path: use the built-in loader for the problem kind.
        if manifest.metadata.kind is ProblemKind.TEXT:
            return TextProblem.from_manifest(manifest)
        if manifest.metadata.kind is ProblemKind.DECISION:
            return load_decision_problem(manifest)
        if manifest.metadata.kind is ProblemKind.MCP:
            return MCPProblem.from_manifest(manifest)
        raise ProblemEvaluationError(f"Problem {problem_id!r} is missing an implementation path.")

    def get_as(self, problem_id: str, expected_type: type[_ProblemSubT]) -> _ProblemSubT:
        """Instantiate one problem by ID and assert the runtime type.

        Args:
            problem_id: Stable catalog identifier.
            expected_type: Required runtime problem class.

        Returns:
            Loaded problem instance narrowed to ``expected_type``.

        Raises:
            KeyError: If the ID is unknown.
            design_research_problems.ProblemEvaluationError: If the packaged implementation metadata is invalid.
            TypeError: If the loaded problem is not an instance of ``expected_type``.
        """
        problem = self.get(problem_id)
        if not isinstance(problem, expected_type):
            raise TypeError(
                f"Problem {problem_id!r} resolved to {type(problem).__name__}, expected {expected_type.__name__}."
            )
        return problem
# Shared module-level registry used by the module-level convenience functions.
_DEFAULT_REGISTRY = ProblemRegistry()


def list_problems() -> tuple[str, ...]:
    """List the identifiers of every packaged problem.

    Returns:
        Problem IDs in ascending sorted order.
    """
    entries = _DEFAULT_REGISTRY.list()
    identifiers = [entry.problem_id for entry in entries]
    return tuple(identifiers)
# Typed overloads for ``get_problem``. They mirror ``ProblemRegistry.get`` so a
# literal ID yields a concrete return type. Type checkers use the first
# matching signature, so specific IDs must stay ahead of the broader aliases and
# the final ``str`` fallback.
@overload
def get_problem(problem_id: Literal["battery_pack_18650_series_parallel"]) -> BatteryPack18650SeriesParallelProblem: ...
@overload
def get_problem(problem_id: Literal["battery_pack_18650_open_ended"]) -> BatteryPack18650OpenEndedProblem: ...
@overload
def get_problem(problem_id: _PlanarTrussProblemId) -> PlanarTrussSpanProblem: ...
@overload
def get_problem(problem_id: Literal["battery_pack_18650_series_parallel_cost_min"]) -> BatteryGridSizingProblem: ...
@overload
def get_problem(
    problem_id: Literal["battery_pack_18650_open_ended_capacity_max"],
) -> BatteryOpenEndedCapacityMaxProblem: ...
@overload
def get_problem(problem_id: Literal["gmpb_default_dynamic_min"]) -> GMPBOptimizationProblem: ...
@overload
def get_problem(problem_id: _SpaceTrussProblemId) -> SpaceTrussSpanProblem: ...
@overload
def get_problem(problem_id: _IoTGrammarProblemId) -> IoTHomeCoolingGrammarProblem: ...
@overload
def get_problem(problem_id: _TrussAPProblemId) -> TrussAPGrammarProblem: ...
@overload
def get_problem(
    problem_id: Literal[
        "planar_truss_span_mass_min",
        "planar_truss_span_deflection_min",
        "planar_truss_span_fos_max",
    ],
) -> PlanarTrussEngineeringOptimizationProblem: ...
@overload
def get_problem(problem_id: Literal["space_truss_span_mass_min"]) -> SpaceTrussEngineeringOptimizationProblem: ...
@overload
def get_problem(problem_id: _OptimizationProblemId) -> OptimizationProblem: ...
@overload
def get_problem(problem_id: _DecisionProblemId) -> DecisionProblem: ...
# Catch-all: any other ID is only known to be some ``Problem``.
@overload
def get_problem(problem_id: str) -> Problem: ...
def get_problem(problem_id: str) -> Problem:
    """Load one packaged problem through the shared default registry.

    Args:
        problem_id: Stable catalog identifier.

    Returns:
        Loaded problem instance.

    Raises:
        KeyError: If the ID is unknown.
        design_research_problems.ProblemEvaluationError: If the packaged implementation metadata is invalid.
    """
    registry = _DEFAULT_REGISTRY
    return registry.get(problem_id)
[docs]
def get_problem_as[ProblemSubT: Problem](problem_id: str, expected_type: type[ProblemSubT]) -> ProblemSubT:
"""Return one problem instance by ID with an asserted runtime type.
Args:
problem_id: Stable catalog identifier.
expected_type: Required runtime problem class.
Returns:
Loaded problem instance narrowed to ``expected_type``.
"""
return _DEFAULT_REGISTRY.get_as(problem_id, expected_type)