# Source code for design_research_agents._implementations._patterns._tree_search_pattern

"""Tree-search reasoning pattern with pluggable generator and evaluator delegates."""

from __future__ import annotations

import json
import math
from collections.abc import Callable, Mapping, Sequence
from typing import Literal, cast

from design_research_agents._contracts._delegate import Delegate, ExecutionResult
from design_research_agents._contracts._workflow import DelegateTarget, LogicStep, LoopStep
from design_research_agents._implementations._patterns._tree_search_delegate_adapters import (
    EvaluatorCallableDelegateAdapter,
    GeneratorCallableDelegateAdapter,
    is_workflow_delegate,
)
from design_research_agents._runtime._common._delegate_invocation import invoke_delegate
from design_research_agents._runtime._patterns import (
    MODE_TREE_SEARCH,
    LoopCallbacks,
    build_compiled_pattern_execution,
    build_loop_callbacks,
    build_pattern_execution_result,
    resolve_pattern_run_context,
    wrap_iteration_handler,
)
from design_research_agents._tracing import Tracer
from design_research_agents.workflow import CompiledExecution, Workflow

# One child candidate emitted by a generator: a structured mapping or a bare
# scalar (scalars are later wrapped as ``{"value": ...}`` by ``_normalize_candidate``).
GeneratorValue = Mapping[str, object] | str | int | float
# Plain-callable generator signature: takes a mapping payload, returns child candidates.
GeneratorDelegate = Callable[[Mapping[str, object]], Sequence[GeneratorValue]]
# Plain-callable evaluator signature: takes a mapping payload, returns a number or a
# mapping (presumably carrying a ``score`` key — confirm in EvaluatorCallableDelegateAdapter).
EvaluatorDelegate = Callable[[Mapping[str, object]], float | int | Mapping[str, object]]

# Supported search strategies; ``TreeSearchPattern.__init__`` rejects anything else.
SearchStrategy = Literal["beam", "mcts"]


class TreeSearchPattern(Delegate):
    """Tree-search pattern with beam and MCTS strategies.

    A generator delegate proposes child candidates for a node and an evaluator
    delegate scores each candidate. Plain callables are wrapped into delegate
    adapters; objects accepted by ``is_workflow_delegate`` are used as-is. The
    search runs inside a single ``LoopStep`` whose state is threaded between
    iterations.
    """

    def __init__(
        self,
        *,
        generator_delegate: GeneratorDelegate | DelegateTarget,
        evaluator_delegate: EvaluatorDelegate | DelegateTarget,
        max_depth: int = 3,
        branch_factor: int = 3,
        beam_width: int = 2,
        search_strategy: SearchStrategy = "beam",
        mcts_exploration_weight: float = 1.4,
        simulation_budget: int | None = None,
        tracer: Tracer | None = None,
    ) -> None:
        """Initialize tree-search reasoning pattern.

        Args:
            generator_delegate: Callable or delegate that proposes child candidates.
            evaluator_delegate: Callable or delegate that scores one candidate.
            max_depth: Maximum depth to expand (the root node is depth 0).
            branch_factor: Maximum number of children kept per expansion.
            beam_width: Beam size for ``"beam"``; also the number of nodes shown
                in each MCTS trace snapshot.
            search_strategy: ``"beam"`` or ``"mcts"``.
            mcts_exploration_weight: UCB exploration constant used by MCTS.
            simulation_budget: MCTS simulation count; ``None`` means
                ``max(1, max_depth * max(branch_factor, beam_width))``.
            tracer: Optional tracer passed to the workflow and compiled execution.

        Raises:
            ValueError: If a numeric setting is out of range or the strategy is unknown.
        """
        if max_depth < 1:
            raise ValueError("max_depth must be >= 1.")
        if branch_factor < 1:
            raise ValueError("branch_factor must be >= 1.")
        if beam_width < 1:
            raise ValueError("beam_width must be >= 1.")
        if search_strategy not in {"beam", "mcts"}:
            raise ValueError("search_strategy must be one of: beam, mcts.")
        if mcts_exploration_weight <= 0:
            raise ValueError("mcts_exploration_weight must be > 0.")
        if simulation_budget is not None and simulation_budget < 1:
            raise ValueError("simulation_budget must be >= 1 when provided.")
        self._generator_delegate = _normalize_generator_delegate(generator_delegate)
        self._evaluator_delegate = _normalize_evaluator_delegate(evaluator_delegate)
        self._max_depth = max_depth
        self._branch_factor = branch_factor
        self._beam_width = beam_width
        self._search_strategy = search_strategy
        self._mcts_exploration_weight = mcts_exploration_weight
        self._simulation_budget = simulation_budget
        self._tracer = tracer
        # Last workflow built by ``_build_workflow`` and its root-node snapshot.
        self.workflow: object | None = None
        self._tree_search_runtime: dict[str, object] | None = None

    def run(
        self,
        prompt: str | object,
        *,
        request_id: str | None = None,
        dependencies: Mapping[str, object] | None = None,
    ) -> ExecutionResult:
        """Execute tree search and return the highest-scoring candidate.

        Equivalent to ``self.compile(...).run()``.
        """
        return self.compile(
            prompt=prompt,
            request_id=request_id,
            dependencies=dependencies,
        ).run()

    def compile(
        self,
        prompt: str | object,
        *,
        request_id: str | None = None,
        dependencies: Mapping[str, object] | None = None,
    ) -> CompiledExecution:
        """Compile one tree-search workflow.

        Builds the loop workflow for the resolved run context and attaches a
        finalizer that converts the loop's final state into the pattern result.
        """
        run_context = resolve_pattern_run_context(
            prompt=prompt,
            default_request_id_prefix=None,
            default_dependencies={},
            request_id=request_id,
            dependencies=dependencies,
        )
        resolved_simulation_budget = self._resolve_simulation_budget()
        input_payload = {
            **run_context.normalized_input,
            "mode": MODE_TREE_SEARCH,
            "search_strategy": self._search_strategy,
            "max_depth": self._max_depth,
            "branch_factor": self._branch_factor,
            "beam_width": self._beam_width,
            "mcts_exploration_weight": self._mcts_exploration_weight,
            "simulation_budget": resolved_simulation_budget,
        }
        workflow = self._build_workflow(
            run_context.prompt,
            request_id=run_context.request_id,
            dependencies=run_context.dependencies,
        )
        runtime = self._tree_search_runtime or {}
        maybe_root_node = runtime.get("root_node")
        # ``_build_workflow`` just stored the root node; the literal below is a fallback.
        root_node = (
            dict(maybe_root_node)
            if isinstance(maybe_root_node, Mapping)
            else {
                "node_id": "root",
                "candidate": {"text": run_context.prompt, "depth": 0},
                "score": 0.0,
                "depth": 0,
                "parent_id": None,
            }
        )
        return build_compiled_pattern_execution(
            workflow=workflow,
            pattern_name="TreeSearchPattern",
            request_id=run_context.request_id,
            dependencies=run_context.dependencies,
            tracer=self._tracer,
            input_payload=input_payload,
            workflow_request_id=f"{run_context.request_id}:tree_search_workflow",
            finalize=lambda workflow_result: _build_tree_search_result(
                workflow_result=workflow_result,
                request_id=run_context.request_id,
                dependencies=run_context.dependencies,
                max_depth=self._max_depth,
                branch_factor=self._branch_factor,
                beam_width=self._beam_width,
                search_strategy=self._search_strategy,
                mcts_exploration_weight=self._mcts_exploration_weight,
                simulation_budget=resolved_simulation_budget,
                root_node=root_node,
            ),
        )

    def _resolve_simulation_budget(self) -> int:
        """Return the explicit simulation budget or a depth/width-derived default."""
        if self._simulation_budget is not None:
            return self._simulation_budget
        return max(1, self._max_depth * max(self._branch_factor, self._beam_width))

    def _build_workflow(
        self,
        prompt: str,
        *,
        request_id: str,
        dependencies: Mapping[str, object],
    ) -> Workflow:
        """Build the tree-search workflow for one resolved run context.

        Side effects: stores the workflow on ``self.workflow`` and a copy of the
        root node on ``self._tree_search_runtime``.
        """
        root_candidate = {"text": prompt, "depth": 0}
        # The root holds the task itself; its score is a placeholder, not an evaluation.
        root_node = {
            "node_id": "root",
            "candidate": root_candidate,
            "score": 0.0,
            "depth": 0,
            "parent_id": None,
            "visits": 0,
            "total_score": 0.0,
            "children": [],
            "expanded": False,
            "terminal": False,
        }
        if self._search_strategy == "beam":
            loop_callbacks, initial_state, max_iterations = self._build_beam_loop(
                prompt=prompt,
                request_id=request_id,
                dependencies=dependencies,
                root_node=root_node,
            )
        else:
            loop_callbacks, initial_state, max_iterations = self._build_mcts_loop(
                prompt=prompt,
                request_id=request_id,
                dependencies=dependencies,
                root_node=root_node,
            )
        workflow = Workflow(
            tool_runtime=None,
            tracer=self._tracer,
            input_schema={"type": "object"},
            steps=[
                LoopStep(
                    step_id="tree_search_loop",
                    steps=(
                        LogicStep(
                            step_id="tree_search_iteration",
                            handler=loop_callbacks.iteration_handler,
                        ),
                    ),
                    max_iterations=max_iterations,
                    initial_state=initial_state,
                    continue_predicate=loop_callbacks.continue_predicate,
                    state_reducer=loop_callbacks.state_reducer,
                    execution_mode="sequential",
                    failure_policy="skip_dependents",
                ),
            ],
        )
        self.workflow = workflow
        self._tree_search_runtime = {"root_node": dict(root_node)}
        return workflow

    def _build_beam_loop(
        self,
        *,
        prompt: str,
        request_id: str,
        dependencies: Mapping[str, object],
        root_node: Mapping[str, object],
    ) -> tuple[LoopCallbacks, dict[str, object], int]:
        """Build beam-search loop callbacks and state.

        Each loop iteration expands one depth level: every frontier node is
        expanded, children are scored, and the top ``beam_width`` become the next
        frontier. Returns ``(callbacks, initial_state, max_iterations=max_depth)``.
        """
        initial_state: dict[str, object] = {
            "node_counter": 0,
            "frontier": [dict(root_node)],
            "frontier_trace": [],
            "explored_nodes": 0,
            "best_node": dict(root_node),
            "terminated_reason": "max_depth_reached",
            "should_continue": True,
        }

        def _run_iteration(context: Mapping[str, object]) -> Mapping[str, object]:
            """Expand the current frontier by one depth level."""
            loop_meta = context.get("_loop")
            depth = 1
            if isinstance(loop_meta, Mapping):
                # The loop's iteration counter doubles as the depth being expanded
                # (presumably 1-based; clamped to >= 1 either way).
                depth = max(1, _safe_int(loop_meta.get("iteration", 1)))
            loop_state = context.get("loop_state")
            current_state = dict(loop_state) if isinstance(loop_state, Mapping) else {}
            raw_frontier = current_state.get("frontier")
            frontier = (
                [dict(node) for node in raw_frontier if isinstance(node, Mapping)]
                if isinstance(raw_frontier, list)
                else []
            )
            node_counter = _safe_int(current_state.get("node_counter"))
            raw_trace = current_state.get("frontier_trace")
            frontier_trace = (
                [dict(item) for item in raw_trace if isinstance(item, Mapping)]
                if isinstance(raw_trace, list)
                else []
            )
            explored_nodes = _safe_int(current_state.get("explored_nodes"))
            maybe_best_node = current_state.get("best_node")
            best_node = dict(maybe_best_node) if isinstance(maybe_best_node, Mapping) else dict(root_node)
            if not frontier:
                return {
                    **current_state,
                    "should_continue": False,
                    "terminated_reason": "no_expansions",
                }
            expanded_nodes: list[dict[str, object]] = []
            for parent_node in frontier:
                children = self._generate_children(
                    prompt=prompt,
                    parent_node=parent_node,
                    depth=depth,
                    request_id=request_id,
                    dependencies=dependencies,
                )
                for child in children[: self._branch_factor]:
                    node_counter += 1
                    child_id = f"node_{node_counter}"
                    child_candidate = _normalize_candidate(child)
                    child_score = self._evaluate_candidate(
                        prompt=prompt,
                        candidate=child_candidate,
                        depth=depth,
                        request_id=request_id,
                        dependencies=dependencies,
                        node_id=child_id,
                    )
                    explored_nodes += 1
                    expanded_nodes.append(
                        {
                            "node_id": child_id,
                            "candidate": child_candidate,
                            "score": child_score,
                            "depth": depth,
                            "parent_id": parent_node.get("node_id"),
                        }
                    )
            if not expanded_nodes:
                return {
                    **current_state,
                    "node_counter": node_counter,
                    "explored_nodes": explored_nodes,
                    "should_continue": False,
                    "terminated_reason": "no_expansions",
                }
            # Highest score first; ties broken deterministically by depth, then node id.
            expanded_nodes.sort(
                key=lambda node: (
                    _safe_float(node.get("score")),
                    -_safe_int(node.get("depth")),
                    str(node.get("node_id", "")),
                ),
                reverse=True,
            )
            frontier = expanded_nodes[: self._beam_width]
            frontier_trace.append(
                {
                    "depth": depth,
                    "frontier": [
                        {
                            "node_id": node["node_id"],
                            "score": node["score"],
                            "candidate": _json_ready(node["candidate"]),
                            "parent_id": node["parent_id"],
                        }
                        for node in frontier
                    ],
                }
            )
            if frontier:
                leader = frontier[0]
                # The root's 0.0 score is a placeholder and its candidate is the task
                # itself, so any evaluated candidate must replace it -- even one with
                # a negative score.
                best_is_placeholder = str(best_node.get("node_id")) == "root"
                if best_is_placeholder or _safe_float(leader.get("score")) >= _safe_float(best_node.get("score")):
                    best_node = leader
            should_continue = depth < self._max_depth
            return {
                "node_counter": node_counter,
                "frontier": frontier,
                "frontier_trace": frontier_trace,
                "explored_nodes": explored_nodes,
                "best_node": best_node,
                "should_continue": should_continue,
                "terminated_reason": "looping" if should_continue else "max_depth_reached",
            }

        wrapped_handler = wrap_iteration_handler(
            _run_iteration,
            error_prefix="TreeSearchPattern beam iteration",
        )
        loop_callbacks = build_loop_callbacks(
            iteration_step_id="tree_search_iteration",
            iteration_handler=wrapped_handler,
        )
        return loop_callbacks, initial_state, self._max_depth

    def _build_mcts_loop(
        self,
        *,
        prompt: str,
        request_id: str,
        dependencies: Mapping[str, object],
        root_node: Mapping[str, object],
    ) -> tuple[LoopCallbacks, dict[str, object], int]:
        """Build MCTS loop callbacks and state.

        Each loop iteration is one simulation: select a path by UCB, expand the
        selected node (unless it is terminal), and backpropagate the best child's
        score along the path. Returns ``(callbacks, initial_state,
        max_iterations=simulation_budget)``.
        """
        simulation_budget = self._resolve_simulation_budget()
        initial_state: dict[str, object] = {
            "node_counter": 0,
            "nodes": {"root": dict(root_node)},
            "frontier_trace": [],
            "explored_nodes": 0,
            "best_node": dict(root_node),
            "simulations_run": 0,
            "terminated_reason": "simulation_budget_reached",
            "should_continue": True,
        }

        def _run_iteration(context: Mapping[str, object]) -> Mapping[str, object]:
            """Run one MCTS simulation (select, expand, backpropagate)."""
            loop_state = context.get("loop_state")
            current_state = dict(loop_state) if isinstance(loop_state, Mapping) else {}
            nodes = _normalize_nodes(current_state.get("nodes"))
            if "root" not in nodes:
                nodes["root"] = dict(root_node)
            node_counter = _safe_int(current_state.get("node_counter"))
            explored_nodes = _safe_int(current_state.get("explored_nodes"))
            simulations_run = _safe_int(current_state.get("simulations_run"))
            raw_trace = current_state.get("frontier_trace")
            frontier_trace = (
                [dict(item) for item in raw_trace if isinstance(item, Mapping)]
                if isinstance(raw_trace, list)
                else []
            )
            maybe_best_node = current_state.get("best_node")
            best_node = dict(maybe_best_node) if isinstance(maybe_best_node, Mapping) else dict(root_node)
            selected_path = _select_mcts_path(
                nodes=nodes,
                max_depth=self._max_depth,
                exploration_weight=self._mcts_exploration_weight,
            )
            if not selected_path:
                return {
                    **current_state,
                    "should_continue": False,
                    "terminated_reason": "no_expansions",
                    "nodes": _json_ready(nodes),
                }
            selected_node_id = selected_path[-1]
            # Mutated in place: ``nodes`` holds fresh dicts from ``_normalize_nodes``.
            selected_node = nodes[selected_node_id]
            selected_depth = _safe_int(selected_node.get("depth"))
            rollout_score = _safe_float(selected_node.get("score"))
            backprop_path = list(selected_path)
            if selected_depth >= self._max_depth or bool(selected_node.get("terminal", False)):
                # Leaves at max depth and dead ends (an earlier expansion yielded no
                # children) are never re-expanded; their own score is backpropagated.
                selected_node["terminal"] = True
            else:
                child_depth = selected_depth + 1
                children = self._generate_children(
                    prompt=prompt,
                    parent_node=selected_node,
                    depth=child_depth,
                    request_id=request_id,
                    dependencies=dependencies,
                )
                expanded_child_ids: list[str] = []
                for child in children[: self._branch_factor]:
                    node_counter += 1
                    child_id = f"node_{node_counter}"
                    child_candidate = _normalize_candidate(child)
                    child_score = self._evaluate_candidate(
                        prompt=prompt,
                        candidate=child_candidate,
                        depth=child_depth,
                        request_id=request_id,
                        dependencies=dependencies,
                        node_id=child_id,
                    )
                    explored_nodes += 1
                    nodes[child_id] = {
                        "node_id": child_id,
                        "candidate": child_candidate,
                        "score": child_score,
                        "depth": child_depth,
                        "parent_id": selected_node_id,
                        "visits": 0,
                        "total_score": 0.0,
                        "children": [],
                        "expanded": False,
                        "terminal": child_depth >= self._max_depth,
                    }
                    expanded_child_ids.append(child_id)
                if expanded_child_ids:
                    selected_node["children"] = expanded_child_ids
                    selected_node["expanded"] = True
                    # The evaluator score of the best new child stands in for a rollout.
                    chosen_child_id = max(
                        expanded_child_ids,
                        key=lambda node_id: (
                            _safe_float(nodes[node_id].get("score")),
                            -_safe_int(nodes[node_id].get("depth")),
                            str(node_id),
                        ),
                    )
                    rollout_score = _safe_float(nodes[chosen_child_id].get("score"))
                    backprop_path.append(chosen_child_id)
                else:
                    selected_node["expanded"] = True
                    selected_node["terminal"] = True
            for node_id in backprop_path:
                node = nodes.get(node_id)
                if not isinstance(node, Mapping):
                    continue
                updated_node = dict(node)
                updated_node["visits"] = _safe_int(node.get("visits")) + 1
                updated_node["total_score"] = _safe_float(node.get("total_score")) + rollout_score
                nodes[node_id] = updated_node
            best_node = _select_best_node(nodes=nodes, fallback=best_node)
            simulations_run += 1
            should_continue = simulations_run < simulation_budget and _has_expandable_nodes(
                nodes=nodes,
                max_depth=self._max_depth,
            )
            terminated_reason = "looping"
            if not should_continue:
                terminated_reason = (
                    "simulation_budget_reached" if simulations_run >= simulation_budget else "no_expansions"
                )
            frontier_trace.append(
                {
                    "simulation": simulations_run,
                    "selected_path": list(backprop_path),
                    "top_nodes": _build_mcts_top_nodes(nodes=nodes, top_k=self._beam_width),
                }
            )
            return {
                "node_counter": node_counter,
                "nodes": _json_ready(nodes),
                "frontier_trace": frontier_trace,
                "explored_nodes": explored_nodes,
                "best_node": _json_ready(best_node),
                "simulations_run": simulations_run,
                "should_continue": should_continue,
                "terminated_reason": terminated_reason,
            }

        wrapped_handler = wrap_iteration_handler(
            _run_iteration,
            error_prefix="TreeSearchPattern MCTS iteration",
        )
        loop_callbacks = build_loop_callbacks(
            iteration_step_id="tree_search_iteration",
            iteration_handler=wrapped_handler,
        )
        return loop_callbacks, initial_state, simulation_budget

    def _generate_children(
        self,
        *,
        prompt: str,
        parent_node: Mapping[str, object],
        depth: int,
        request_id: str,
        dependencies: Mapping[str, object],
    ) -> list[Mapping[str, object] | str | int | float]:
        """Generate child candidates from one parent node.

        The generator receives a sorted-key JSON prompt describing the task and
        parent. A failed invocation yields an empty list.
        """
        delegate_input = {
            "task": prompt,
            "depth": depth,
            "parent": _json_ready(parent_node.get("candidate", {})),
            "parent_node": _json_ready(parent_node),
            "branch_factor": self._branch_factor,
            "search_strategy": self._search_strategy,
        }
        delegate_prompt = json.dumps(delegate_input, ensure_ascii=True, sort_keys=True)
        delegate_invocation = invoke_delegate(
            delegate=self._generator_delegate,
            prompt=delegate_prompt,
            step_context=None,
            request_id=f"{request_id}:tree_search:generator:{depth}:{parent_node.get('node_id')}",
            execution_mode="sequential",
            failure_policy="skip_dependents",
            dependencies=dependencies,
        )
        delegate_result = delegate_invocation.result
        if not delegate_result.success:
            return []
        return _extract_candidate_list(delegate_result.output)

    def _evaluate_candidate(
        self,
        *,
        prompt: str,
        candidate: Mapping[str, object],
        depth: int,
        request_id: str,
        dependencies: Mapping[str, object],
        node_id: str | None = None,
    ) -> float:
        """Evaluate one candidate and return its score.

        Args:
            node_id: Optional id of the node being scored. When given it is
                appended to the delegate request id so sibling evaluations at the
                same depth do not share a request id.

        Returns:
            The extracted score, or ``0.0`` when the evaluator invocation fails.
        """
        delegate_input = {
            "task": prompt,
            "depth": depth,
            "candidate": _json_ready(candidate),
            "search_strategy": self._search_strategy,
        }
        delegate_prompt = json.dumps(delegate_input, ensure_ascii=True, sort_keys=True)
        evaluator_request_id = f"{request_id}:tree_search:evaluator:{depth}"
        if node_id is not None:
            evaluator_request_id = f"{evaluator_request_id}:{node_id}"
        delegate_invocation = invoke_delegate(
            delegate=self._evaluator_delegate,
            prompt=delegate_prompt,
            step_context=None,
            request_id=evaluator_request_id,
            execution_mode="sequential",
            failure_policy="skip_dependents",
            dependencies=dependencies,
        )
        delegate_result = delegate_invocation.result
        if not delegate_result.success:
            return 0.0
        return _extract_score(delegate_result.output)
def _build_tree_search_result(
    *,
    workflow_result: ExecutionResult,
    request_id: str,
    dependencies: Mapping[str, object],
    max_depth: int,
    branch_factor: int,
    beam_width: int,
    search_strategy: SearchStrategy,
    mcts_exploration_weight: float,
    simulation_budget: int,
    root_node: Mapping[str, object],
) -> ExecutionResult:
    """Build the final tree-search result from one workflow execution.

    Reads ``final_state`` from the ``tree_search_loop`` step output and reports
    the best node's candidate and score together with search diagnostics.

    Raises:
        RuntimeError: If the workflow result has no ``tree_search_loop`` step.
    """
    loop_step_result = workflow_result.step_results.get("tree_search_loop")
    if loop_step_result is None:
        raise RuntimeError("Tree search loop step result is missing.")
    loop_output = loop_step_result.output
    final_state_raw = loop_output.get("final_state")
    final_state = dict(final_state_raw) if isinstance(final_state_raw, Mapping) else {}
    best_node_raw = final_state.get("best_node")
    best_node = dict(best_node_raw) if isinstance(best_node_raw, Mapping) else dict(root_node)
    best_candidate_raw = best_node.get("candidate")
    best_candidate = (
        dict(best_candidate_raw)
        if isinstance(best_candidate_raw, Mapping)
        else {"value": _json_ready(best_candidate_raw)}
    )
    best_score = _safe_float(best_node.get("score"))
    raw_frontier_trace = final_state.get("frontier_trace")
    frontier_trace = (
        [dict(item) for item in raw_frontier_trace if isinstance(item, Mapping)]
        if isinstance(raw_frontier_trace, list)
        else []
    )
    explored_nodes = _safe_int(final_state.get("explored_nodes"))
    simulations_run = _safe_int(final_state.get("simulations_run"))
    # Prefer the loop state's own reason; fall back to the loop output, then a default.
    terminated_reason = str(
        final_state.get(
            "terminated_reason",
            loop_output.get("terminated_reason", "max_depth_reached"),
        )
    )
    success = bool(loop_output.get("success", loop_step_result.success))
    details: dict[str, object] = {
        "search_strategy": search_strategy,
        "best_node": _json_ready(best_node),
        "explored_nodes": explored_nodes,
        "frontier_trace": [_json_ready(item) for item in frontier_trace],
    }
    if search_strategy == "mcts":
        details["simulations_run"] = simulations_run
        details["simulation_budget"] = simulation_budget
    return build_pattern_execution_result(
        success=success,
        final_output={
            "best_candidate": _json_ready(best_candidate),
            "best_score": best_score,
        },
        terminated_reason=terminated_reason,
        details=details,
        workflow_payload=workflow_result.to_dict(),
        artifacts=workflow_result.output.get("artifacts", []),
        request_id=request_id,
        dependencies=dependencies,
        mode=MODE_TREE_SEARCH,
        metadata={
            "max_depth": max_depth,
            "branch_factor": branch_factor,
            "beam_width": beam_width,
            "search_strategy": search_strategy,
            "mcts_exploration_weight": mcts_exploration_weight,
            "simulation_budget": simulation_budget,
        },
        error=loop_step_result.error,
    )


def _normalize_nodes(raw_nodes: object) -> dict[str, dict[str, object]]:
    """Normalize a stored node mapping into fresh, JSON-ready dicts.

    Entries with non-string ids or non-mapping payloads are dropped.
    """
    if not isinstance(raw_nodes, Mapping):
        return {}
    normalized: dict[str, dict[str, object]] = {}
    for node_id, raw_node in raw_nodes.items():
        if not isinstance(node_id, str) or not isinstance(raw_node, Mapping):
            continue
        normalized[node_id] = {str(key): _json_ready(value) for key, value in raw_node.items()}
    return normalized


def _select_best_node(
    *,
    nodes: Mapping[str, Mapping[str, object]],
    fallback: Mapping[str, object],
) -> dict[str, object]:
    """Select the best-scoring node, using deterministic tie-breaks.

    The ``"root"`` node holds the task itself with a placeholder score of 0.0, so
    it is considered only when no other node exists. Ties prefer shallower nodes,
    then the lexicographically larger node id. Returns a copy of ``fallback`` if
    ``nodes`` is empty.
    """
    if not nodes:
        return dict(fallback)
    candidate_ids = [node_id for node_id in nodes if node_id != "root"] or list(nodes)
    best_node_id = max(
        candidate_ids,
        key=lambda node_id: (
            _safe_float(nodes[node_id].get("score")),
            -_safe_int(nodes[node_id].get("depth")),
            str(node_id),
        ),
    )
    selected = nodes.get(best_node_id)
    if isinstance(selected, Mapping):
        return dict(selected)
    return dict(fallback)


def _has_expandable_nodes(*, nodes: Mapping[str, Mapping[str, object]], max_depth: int) -> bool:
    """Return whether any non-terminal node below ``max_depth`` can still be expanded."""
    for node in nodes.values():
        depth = _safe_int(node.get("depth"))
        if depth >= max_depth:
            continue
        if bool(node.get("terminal", False)):
            continue
        if not bool(node.get("expanded", False)):
            return True
        child_ids = node.get("children")
        if isinstance(child_ids, list):
            for child_id in child_ids:
                child_node = nodes.get(str(child_id))
                if isinstance(child_node, Mapping) and not bool(child_node.get("terminal", False)):
                    return True
    return False


def _select_mcts_path(
    *,
    nodes: Mapping[str, Mapping[str, object]],
    max_depth: int,
    exploration_weight: float,
) -> list[str]:
    """Select one MCTS traversal path with deterministic tie-breaks.

    Descends from ``"root"``, choosing the child with the highest UCB value at
    each step. Stops at a node at ``max_depth``, an unexpanded node, or a node
    without known children. Returns ``[]`` if the root is missing or a node on
    the path is not a mapping. Revisiting a node ends the walk, which guards
    against cyclic state.
    """
    if "root" not in nodes:
        return []
    path = ["root"]
    current_id = "root"
    visited: set[str] = {"root"}
    while True:
        current_node = nodes.get(current_id)
        if not isinstance(current_node, Mapping):
            return []
        current_depth = _safe_int(current_node.get("depth"))
        if current_depth >= max_depth:
            return path
        expanded = bool(current_node.get("expanded", False))
        raw_children = current_node.get("children")
        child_ids = [str(child_id) for child_id in raw_children] if isinstance(raw_children, list) else []
        child_ids = [child_id for child_id in child_ids if child_id in nodes]
        if not expanded or not child_ids:
            return path
        parent_visits = max(1, _safe_int(current_node.get("visits")))
        next_id = max(
            child_ids,
            key=lambda child_id: (
                _mcts_ucb(
                    node=nodes[child_id],
                    parent_visits=parent_visits,
                    exploration_weight=exploration_weight,
                ),
                -_safe_int(nodes[child_id].get("depth")),
                str(child_id),
            ),
        )
        if next_id in visited:
            return path
        path.append(next_id)
        visited.add(next_id)
        current_id = next_id


def _mcts_ucb(*, node: Mapping[str, object], parent_visits: int, exploration_weight: float) -> float:
    """Return the UCB score for one node; unvisited nodes score ``inf``."""
    visits = _safe_int(node.get("visits"))
    if visits <= 0:
        return float("inf")
    avg_value = _safe_float(node.get("total_score")) / visits
    exploration_term = exploration_weight * math.sqrt(math.log(parent_visits + 1) / visits)
    return avg_value + exploration_term


def _build_mcts_top_nodes(*, nodes: Mapping[str, Mapping[str, object]], top_k: int) -> list[dict[str, object]]:
    """Return summaries of the top-ranked nodes (at least one) for trace snapshots."""
    ranked = sorted(
        (dict(node) for node in nodes.values() if isinstance(node, Mapping) and str(node.get("node_id", ""))),
        key=lambda node: (
            _safe_float(node.get("score")),
            _safe_int(node.get("visits")),
            -_safe_int(node.get("depth")),
            str(node.get("node_id", "")),
        ),
        reverse=True,
    )
    selected = ranked[: max(1, top_k)]
    return [
        {
            "node_id": str(node.get("node_id", "")),
            "score": _safe_float(node.get("score")),
            "visits": _safe_int(node.get("visits")),
            "depth": _safe_int(node.get("depth")),
        }
        for node in selected
    ]


def _normalize_generator_delegate(
    delegate: GeneratorDelegate | DelegateTarget,
) -> DelegateTarget:
    """Normalize a generator delegate into the object-delegate contract."""
    if is_workflow_delegate(delegate):
        return delegate
    return GeneratorCallableDelegateAdapter(cast(GeneratorDelegate, delegate))


def _normalize_evaluator_delegate(
    delegate: EvaluatorDelegate | DelegateTarget,
) -> DelegateTarget:
    """Normalize an evaluator delegate into the object-delegate contract."""
    if is_workflow_delegate(delegate):
        return delegate
    return EvaluatorCallableDelegateAdapter(cast(EvaluatorDelegate, delegate))


def _is_agent_like(delegate: object) -> bool:
    """Return whether delegate appears to implement an agent-like ``run`` method."""
    return is_workflow_delegate(delegate)


def _extract_candidate_list(output: Mapping[str, object]) -> list[GeneratorValue]:
    """Extract a candidate list from a delegate output payload.

    Lookup order: ``candidates`` list, single ``candidate``, then ``model_text``
    parsed as JSON (a list, or an object with ``candidates``/``candidate``).
    Returns ``[]`` when nothing usable is found.
    """
    candidates = output.get("candidates")
    if isinstance(candidates, list):
        return _coerce_generator_values(candidates)
    candidate = output.get("candidate")
    normalized_candidate = _normalize_generator_value(candidate)
    if normalized_candidate is not None:
        return [normalized_candidate]
    model_text = output.get("model_text")
    if isinstance(model_text, str):
        try:
            parsed = json.loads(model_text)
        except json.JSONDecodeError:
            return []
        if isinstance(parsed, list):
            return _coerce_generator_values(parsed)
        if isinstance(parsed, Mapping):
            parsed_candidates = parsed.get("candidates")
            if isinstance(parsed_candidates, list):
                return _coerce_generator_values(parsed_candidates)
            parsed_candidate = parsed.get("candidate")
            normalized_parsed_candidate = _normalize_generator_value(parsed_candidate)
            if normalized_parsed_candidate is not None:
                return [normalized_parsed_candidate]
    return []


def _extract_score(output: Mapping[str, object]) -> float:
    """Extract a numeric score from a delegate output payload.

    Lookup order:
    1. A numeric ``score`` key.
    2. ``model_text`` parsed as JSON, either as an object with a numeric
       ``score`` or as a bare JSON number.

    NaN and unusable values fall back to ``0.0`` so that score ordering stays
    well-defined.
    """
    direct_score = _coerce_score(output.get("score"))
    if direct_score is not None:
        return direct_score
    model_text = output.get("model_text")
    if not isinstance(model_text, str):
        return 0.0
    try:
        parsed = json.loads(model_text)
    except json.JSONDecodeError:
        return 0.0
    if isinstance(parsed, Mapping):
        parsed = parsed.get("score")
    parsed_score = _coerce_score(parsed)
    return 0.0 if parsed_score is None else parsed_score


def _coerce_score(value: object) -> float | None:
    """Return ``value`` as a float score, or ``None`` when it is unusable.

    Accepts ``int``/``float`` (``bool`` included, as before). Rejects NaN, which
    would make ``max``/``sort`` results depend on input order.
    """
    if not isinstance(value, (int, float)):
        return None
    try:
        score = float(value)
    except OverflowError:
        # Integers too large to represent as a float (e.g. a huge JSON literal).
        return None
    return None if math.isnan(score) else score


def _normalize_candidate(
    candidate: Mapping[str, object] | str | int | float,
) -> dict[str, object]:
    """Normalize one candidate payload to a mapping; scalars become ``{"value": ...}``."""
    if isinstance(candidate, Mapping):
        return {str(key): _json_ready(value) for key, value in candidate.items()}
    return {"value": _json_ready(candidate)}


def _normalize_generator_value(value: object) -> GeneratorValue | None:
    """Normalize a raw candidate to the supported generator union, or ``None``."""
    if isinstance(value, Mapping):
        return {str(key): _json_ready(item) for key, item in value.items()}
    if isinstance(value, (str, int, float)):
        return value
    return None


def _coerce_generator_values(values: Sequence[object]) -> list[GeneratorValue]:
    """Coerce heterogeneous values to generator union values, dropping unsupported ones."""
    normalized: list[GeneratorValue] = []
    for value in values:
        normalized_value = _normalize_generator_value(value)
        if normalized_value is not None:
            normalized.append(normalized_value)
    return normalized


def _safe_float(value: object) -> float:
    """Convert values to float with deterministic fallback to zero."""
    if isinstance(value, bool):
        return float(int(value))
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value.strip())
        except ValueError:
            return 0.0
    return 0.0


def _safe_int(value: object) -> int:
    """Convert values to int with deterministic fallback to zero.

    Non-finite floats (``inf``/``nan``) also map to zero; ``int()`` would
    otherwise raise ``OverflowError``/``ValueError`` for them.
    """
    if isinstance(value, bool):
        return int(value)
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        return int(value) if math.isfinite(value) else 0
    if isinstance(value, str):
        try:
            return int(value.strip())
        except ValueError:
            return 0
    return 0


def _json_ready(value: object) -> object:
    """Recursively convert values into JSON-safe shapes (unknown types become ``str``)."""
    if isinstance(value, Mapping):
        return {str(key): _json_ready(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_json_ready(item) for item in value]
    if isinstance(value, tuple):
        return [_json_ready(item) for item in value]
    if isinstance(value, (str, int, float, bool)) or value is None:
        return value
    return str(value)


__all__ = ["TreeSearchPattern"]