# Source code for design_research_agents._implementations._patterns._beam_search_pattern

"""Tree-search reasoning pattern with pluggable generator and evaluator delegates."""

from __future__ import annotations

import json
from collections.abc import Callable, Mapping, Sequence
from typing import TypeGuard, cast

from design_research_agents._contracts._delegate import Delegate, ExecutionResult
from design_research_agents._contracts._workflow import DelegateTarget, LogicStep, LoopStep
from design_research_agents._runtime._common._delegate_invocation import invoke_delegate
from design_research_agents._runtime._patterns import (
    MODE_BEAM_SEARCH,
    build_compiled_pattern_execution,
    build_pattern_execution_result,
    resolve_pattern_run_context,
)
from design_research_agents._tracing import Tracer
from design_research_agents.workflow import CompiledExecution, Workflow

GeneratorValue = Mapping[str, object] | str | int | float
GeneratorDelegate = Callable[[Mapping[str, object]], Sequence[GeneratorValue]]
EvaluatorDelegate = Callable[[Mapping[str, object]], float | int | Mapping[str, object]]


class BeamSearchPattern(Delegate):
    """Beam-style tree search over generated candidate states."""

    def __init__(
        self,
        *,
        generator_delegate: GeneratorDelegate | DelegateTarget,
        evaluator_delegate: EvaluatorDelegate | DelegateTarget,
        max_depth: int = 3,
        branch_factor: int = 3,
        beam_width: int = 2,
        tracer: Tracer | None = None,
    ) -> None:
        """Initialize the tree-search pattern and validate its sizing knobs.

        Args:
            generator_delegate: Delegate that expands one candidate into children.
            evaluator_delegate: Delegate that assigns a score to one candidate.
            max_depth: Maximum expansion depth.
            branch_factor: Max children retained per expanded node.
            beam_width: Max frontier width kept after each depth.
            tracer: Optional tracer dependency.

        Raises:
            ValueError: Raised when depth/branch/beam settings are invalid.
        """
        # Reject invalid sizing before touching any delegate wiring.
        if max_depth < 1:
            raise ValueError("max_depth must be >= 1.")
        if branch_factor < 1:
            raise ValueError("branch_factor must be >= 1.")
        if beam_width < 1:
            raise ValueError("beam_width must be >= 1.")
        self._max_depth = max_depth
        self._branch_factor = branch_factor
        self._beam_width = beam_width
        self._tracer = tracer
        # Plain callables are wrapped in adapter objects so both delegate
        # flavors share one object-delegate invocation path.
        self._generator_delegate = _normalize_generator_delegate(generator_delegate)
        self._evaluator_delegate = _normalize_evaluator_delegate(evaluator_delegate)
        # Populated by _build_workflow on each compile; kept for inspection.
        self.workflow: object | None = None
        self._beam_search_runtime: dict[str, object] | None = None
[docs] def run( self, prompt: str, *, request_id: str | None = None, dependencies: Mapping[str, object] | None = None, ) -> ExecutionResult: """Execute tree search and return the highest-scoring candidate.""" return self.compile( prompt=prompt, request_id=request_id, dependencies=dependencies, ).run()
[docs] def compile( self, prompt: str, *, request_id: str | None = None, dependencies: Mapping[str, object] | None = None, ) -> CompiledExecution: """Compile one tree-search workflow.""" run_context = resolve_pattern_run_context( default_request_id_prefix=None, default_dependencies={}, request_id=request_id, dependencies=dependencies, ) input_payload = { "prompt": prompt, "mode": MODE_BEAM_SEARCH, "max_depth": self._max_depth, "branch_factor": self._branch_factor, "beam_width": self._beam_width, } workflow = self._build_workflow( prompt, request_id=run_context.request_id, dependencies=run_context.dependencies, ) runtime = self._beam_search_runtime or {} maybe_root_node = runtime.get("root_node") root_node = ( dict(maybe_root_node) if isinstance(maybe_root_node, Mapping) else { "node_id": "root", "candidate": {"text": prompt, "depth": 0}, "score": 0.0, "depth": 0, "parent_id": None, } ) return build_compiled_pattern_execution( workflow=workflow, pattern_name="BeamSearchPattern", request_id=run_context.request_id, dependencies=run_context.dependencies, tracer=self._tracer, input_payload=input_payload, workflow_request_id=f"{run_context.request_id}:beam_search_workflow", finalize=lambda workflow_result: _build_beam_search_result( workflow_result=workflow_result, request_id=run_context.request_id, dependencies=run_context.dependencies, max_depth=self._max_depth, branch_factor=self._branch_factor, beam_width=self._beam_width, root_node=root_node, ), )
def _build_workflow( self, prompt: str, *, request_id: str, dependencies: Mapping[str, object], ) -> Workflow: """Build the tree-search workflow for one resolved run context.""" root_candidate = {"text": prompt, "depth": 0} root_node = { "node_id": "root", "candidate": root_candidate, "score": 0.0, "depth": 0, "parent_id": None, } initial_state: dict[str, object] = { "node_counter": 0, "frontier": [root_node], "frontier_trace": [], "explored_nodes": 0, "best_node": root_node, "terminated_reason": "max_depth_reached", "should_continue": True, } def _continue_predicate(iteration: int, state: Mapping[str, object]) -> bool: """Decide whether another search iteration should run. Args: iteration: One-based loop iteration index. state: Current loop state mapping. Returns: ``True`` when search should continue. """ del iteration return bool(state.get("should_continue", True)) def _run_iteration(context: Mapping[str, object]) -> Mapping[str, object]: """Execute one expansion/evaluation iteration. Args: context: Workflow step context for the current iteration. Returns: Updated loop state for the next iteration. 
""" loop_meta = context.get("_loop") depth = 1 if isinstance(loop_meta, Mapping): depth = max(1, _safe_int(loop_meta.get("iteration", 1))) loop_state = context.get("loop_state") current_state = dict(loop_state) if isinstance(loop_state, Mapping) else {} raw_frontier = current_state.get("frontier") frontier = ( [dict(node) for node in raw_frontier if isinstance(node, Mapping)] if isinstance(raw_frontier, list) else [] ) node_counter = _safe_int(current_state.get("node_counter")) frontier_trace = ( [dict(item) for item in current_state.get("frontier_trace", [])] if isinstance(current_state.get("frontier_trace"), list) else [] ) explored_nodes = _safe_int(current_state.get("explored_nodes")) maybe_best_node = current_state.get("best_node") best_node = dict(maybe_best_node) if isinstance(maybe_best_node, Mapping) else root_node if not frontier: return { **current_state, "should_continue": False, "terminated_reason": "no_expansions", } expanded_nodes: list[dict[str, object]] = [] for parent_node in frontier: children = self._generate_children( prompt=prompt, parent_node=parent_node, depth=depth, request_id=request_id, dependencies=dependencies, ) for child in children[: self._branch_factor]: node_counter += 1 child_candidate = _normalize_candidate(child) child_score = self._evaluate_candidate( prompt=prompt, candidate=child_candidate, depth=depth, request_id=request_id, dependencies=dependencies, ) explored_nodes += 1 expanded_nodes.append( { "node_id": f"node_{node_counter}", "candidate": child_candidate, "score": child_score, "depth": depth, "parent_id": parent_node.get("node_id"), } ) if not expanded_nodes: return { **current_state, "node_counter": node_counter, "explored_nodes": explored_nodes, "should_continue": False, "terminated_reason": "no_expansions", } expanded_nodes.sort( key=lambda node: ( _safe_float(node.get("score")), -_safe_int(node.get("depth")), str(node.get("node_id", "")), ), reverse=True, ) frontier = expanded_nodes[: self._beam_width] 
frontier_trace.append( { "depth": depth, "frontier": [ { "node_id": node["node_id"], "score": node["score"], "candidate": _json_ready(node["candidate"]), "parent_id": node["parent_id"], } for node in frontier ], } ) if frontier and _safe_float(frontier[0].get("score")) >= _safe_float(best_node.get("score")): best_node = frontier[0] should_continue = depth < self._max_depth return { "node_counter": node_counter, "frontier": frontier, "frontier_trace": frontier_trace, "explored_nodes": explored_nodes, "best_node": best_node, "should_continue": should_continue, "terminated_reason": "looping" if should_continue else "max_depth_reached", } def _state_reducer( state: Mapping[str, object], iteration_result: ExecutionResult, iteration: int, ) -> Mapping[str, object]: """Fold one loop iteration output into accumulated search state. Args: state: Current accumulated loop state. iteration_result: Workflow result for one loop iteration. iteration: One-based loop iteration index. Returns: Reduced loop state mapping. 
""" del iteration iteration_step = iteration_result.step_results.get("beam_search_iteration") if iteration_step is None or not getattr(iteration_step, "success", False): return dict(state) output = getattr(iteration_step, "output", {}) return dict(output) if isinstance(output, Mapping) else dict(state) workflow = Workflow( tool_runtime=None, tracer=self._tracer, input_schema={"type": "object"}, steps=[ LoopStep( step_id="beam_search_loop", steps=( LogicStep( step_id="beam_search_iteration", handler=_run_iteration, ), ), max_iterations=self._max_depth, initial_state=initial_state, continue_predicate=_continue_predicate, state_reducer=_state_reducer, execution_mode="sequential", failure_policy="skip_dependents", ), ], ) self.workflow = workflow self._beam_search_runtime = {"root_node": dict(root_node)} return workflow def _run_beam_search( self, *, prompt: str, request_id: str, dependencies: Mapping[str, object], ) -> ExecutionResult: """Beam-style expansion and scoring through workflow loop primitives. Args: prompt: Task prompt to optimize through tree search. request_id: Resolved request id for this run. dependencies: Normalized dependency mapping passed to delegates. Returns: Final tree-search execution result. Raises: RuntimeError: Raised when the loop step result is missing. 
""" workflow = self._build_workflow( prompt, request_id=request_id, dependencies=dependencies, ) runtime = self._beam_search_runtime or {} maybe_root_node = runtime.get("root_node") root_node = ( dict(maybe_root_node) if isinstance(maybe_root_node, Mapping) else { "node_id": "root", "candidate": {"text": prompt, "depth": 0}, "score": 0.0, "depth": 0, "parent_id": None, } ) workflow_result = workflow.run( input={}, execution_mode="sequential", failure_policy="skip_dependents", request_id=f"{request_id}:beam_search_workflow", dependencies=dependencies, ) return _build_beam_search_result( workflow_result=workflow_result, request_id=request_id, dependencies=dependencies, max_depth=self._max_depth, branch_factor=self._branch_factor, beam_width=self._beam_width, root_node=root_node, ) def _generate_children( self, *, prompt: str, parent_node: Mapping[str, object], depth: int, request_id: str, dependencies: Mapping[str, object], ) -> list[Mapping[str, object] | str | int | float]: """Generate child candidates from one parent node. Args: prompt: Task prompt. parent_node: Parent frontier node payload. depth: One-based expansion depth. request_id: Resolved request identifier. dependencies: Normalized dependency mapping. Returns: Generated child candidates. 
""" delegate_input = { "task": prompt, "depth": depth, "parent": _json_ready(parent_node.get("candidate", {})), "parent_node": _json_ready(parent_node), "branch_factor": self._branch_factor, } delegate_prompt = json.dumps(delegate_input, ensure_ascii=True, sort_keys=True) delegate_invocation = invoke_delegate( delegate=self._generator_delegate, prompt=delegate_prompt, step_context=None, request_id=f"{request_id}:beam_search:generator:{depth}:{parent_node.get('node_id')}", execution_mode="sequential", failure_policy="skip_dependents", dependencies=dependencies, ) delegate_result = delegate_invocation.result if not delegate_result.success: return [] return _extract_candidate_list(delegate_result.output) def _evaluate_candidate( self, *, prompt: str, candidate: Mapping[str, object], depth: int, request_id: str, dependencies: Mapping[str, object], ) -> float: """Evaluate one candidate and return normalized score. Args: prompt: Task prompt. candidate: Candidate payload. depth: One-based candidate depth. request_id: Resolved request identifier. dependencies: Normalized dependency mapping. Returns: Candidate score. """ delegate_input = { "task": prompt, "depth": depth, "candidate": _json_ready(candidate), } delegate_prompt = json.dumps(delegate_input, ensure_ascii=True, sort_keys=True) delegate_invocation = invoke_delegate( delegate=self._evaluator_delegate, prompt=delegate_prompt, step_context=None, request_id=f"{request_id}:beam_search:evaluator:{depth}", execution_mode="sequential", failure_policy="skip_dependents", dependencies=dependencies, ) delegate_result = delegate_invocation.result if not delegate_result.success: return 0.0 return _extract_score(delegate_result.output)
def _build_beam_search_result(
    *,
    workflow_result: ExecutionResult,
    request_id: str,
    dependencies: Mapping[str, object],
    max_depth: int,
    branch_factor: int,
    beam_width: int,
    root_node: Mapping[str, object],
) -> ExecutionResult:
    """Build the final beam-search result from one workflow execution."""
    loop_step_result = workflow_result.step_results.get("beam_search_loop")
    if loop_step_result is None:
        raise RuntimeError("Beam search loop step result is missing.")
    loop_output = loop_step_result.output
    # Every value pulled from the loop output is isinstance-checked before
    # use, with deterministic fallbacks, because it round-tripped through
    # the workflow runtime.
    final_state_raw = loop_output.get("final_state")
    final_state = dict(final_state_raw) if isinstance(final_state_raw, Mapping) else {}
    best_node_raw = final_state.get("best_node")
    # Fall back to the root node when the loop never produced a best node.
    best_node = dict(best_node_raw) if isinstance(best_node_raw, Mapping) else dict(root_node)
    best_candidate_raw = best_node.get("candidate")
    best_candidate = (
        dict(best_candidate_raw)
        if isinstance(best_candidate_raw, Mapping)
        else {"value": _json_ready(best_candidate_raw)}
    )
    best_score = _safe_float(best_node.get("score"))
    raw_frontier_trace = final_state.get("frontier_trace")
    frontier_trace = (
        [dict(item) for item in raw_frontier_trace if isinstance(item, Mapping)]
        if isinstance(raw_frontier_trace, list)
        else []
    )
    explored_nodes = _safe_int(final_state.get("explored_nodes"))
    # Prefer the reason recorded in final state; fall back to the loop-level
    # reason, then to the default.
    terminated_reason = str(
        final_state.get(
            "terminated_reason",
            loop_output.get("terminated_reason", "max_depth_reached"),
        )
    )
    success = bool(loop_output.get("success", loop_step_result.success))
    return build_pattern_execution_result(
        success=success,
        final_output={
            "best_candidate": _json_ready(best_candidate),
            "best_score": best_score,
        },
        terminated_reason=terminated_reason,
        details={
            "best_node": _json_ready(best_node),
            "explored_nodes": explored_nodes,
            "frontier_trace": [_json_ready(item) for item in frontier_trace],
        },
        workflow_payload=workflow_result.to_dict(),
        artifacts=workflow_result.output.get("artifacts", []),
        request_id=request_id,
        dependencies=dependencies,
        mode=MODE_BEAM_SEARCH,
        metadata={
            "max_depth": max_depth,
            "branch_factor": branch_factor,
            "beam_width": beam_width,
        },
        error=loop_step_result.error,
    )


def _normalize_generator_delegate(
    delegate: GeneratorDelegate | DelegateTarget,
) -> DelegateTarget:
    """Normalize generator delegate into one object-delegate contract.

    Args:
        delegate: Generator callable or workflow-compatible delegate object.

    Returns:
        Workflow-compatible delegate object.
    """
    if _is_workflow_delegate(delegate):
        return delegate
    return _GeneratorCallableDelegateAdapter(cast(GeneratorDelegate, delegate))


def _normalize_evaluator_delegate(
    delegate: EvaluatorDelegate | DelegateTarget,
) -> DelegateTarget:
    """Normalize evaluator delegate into one object-delegate contract.

    Args:
        delegate: Evaluator callable or workflow-compatible delegate object.

    Returns:
        Workflow-compatible delegate object.
    """
    if _is_workflow_delegate(delegate):
        return delegate
    return _EvaluatorCallableDelegateAdapter(cast(EvaluatorDelegate, delegate))


class _GeneratorCallableDelegateAdapter:
    """Adapter that wraps callable generator delegates as agent-like delegates."""

    def __init__(self, delegate: GeneratorDelegate) -> None:
        """Store callable generator delegate.

        Args:
            delegate: Callable generator delegate.
        """
        self._delegate = delegate

    def run(
        self,
        *,
        context: Mapping[str, object] | None = None,
        execution_mode: str = "dag",
        failure_policy: str = "skip_dependents",
        request_id: str | None = None,
        dependencies: Mapping[str, object] | None = None,
    ) -> ExecutionResult:
        """Execute callable generator delegate and normalize output.

        Args:
            context: Optional delegate-runner context mapping.
            execution_mode: Unused workflow execution mode passthrough.
            failure_policy: Unused workflow failure policy passthrough.
            request_id: Optional request identifier.
            dependencies: Optional dependency mapping.

        Returns:
            Normalized execution result containing ``candidates`` output.
        """
        del execution_mode, failure_policy, request_id, dependencies
        prompt = ""
        if isinstance(context, Mapping):
            raw_prompt = context.get("prompt")
            prompt = raw_prompt if isinstance(raw_prompt, str) else ""
        # The prompt is the JSON payload serialized by _generate_children.
        parsed_context = _parse_json_context(prompt)
        try:
            raw_children = self._delegate(parsed_context)
        except Exception as exc:
            # Deliberate broad catch: adapter boundary converts any delegate
            # failure into a failed ExecutionResult instead of propagating.
            return ExecutionResult(
                output={"error": str(exc)},
                success=False,
                tool_results=[],
                model_response=None,
                metadata={"delegate_type": "callable_generator"},
            )
        return ExecutionResult(
            output={"candidates": list(raw_children)},
            success=True,
            tool_results=[],
            model_response=None,
            metadata={"delegate_type": "callable_generator"},
        )


class _EvaluatorCallableDelegateAdapter:
    """Adapter that wraps callable evaluator delegates as agent-like delegates."""

    def __init__(self, delegate: EvaluatorDelegate) -> None:
        """Store callable evaluator delegate.

        Args:
            delegate: Callable evaluator delegate.
        """
        self._delegate = delegate

    def run(
        self,
        *,
        context: Mapping[str, object] | None = None,
        execution_mode: str = "dag",
        failure_policy: str = "skip_dependents",
        request_id: str | None = None,
        dependencies: Mapping[str, object] | None = None,
    ) -> ExecutionResult:
        """Execute callable evaluator delegate and normalize score output.

        Args:
            context: Optional delegate-runner context mapping.
            execution_mode: Unused workflow execution mode passthrough.
            failure_policy: Unused workflow failure policy passthrough.
            request_id: Optional request identifier.
            dependencies: Optional dependency mapping.

        Returns:
            Normalized execution result containing a numeric score payload.
        """
        del execution_mode, failure_policy, request_id, dependencies
        prompt = ""
        if isinstance(context, Mapping):
            raw_prompt = context.get("prompt")
            prompt = raw_prompt if isinstance(raw_prompt, str) else ""
        parsed_context = _parse_json_context(prompt)
        try:
            raw_score = self._delegate(parsed_context)
        except Exception as exc:
            # Deliberate broad catch: see generator adapter.
            return ExecutionResult(
                output={"error": str(exc)},
                success=False,
                tool_results=[],
                model_response=None,
                metadata={"delegate_type": "callable_evaluator"},
            )
        # Numeric results become {"score": float}; mapping results pass
        # through; anything else falls back to a zero score.
        if isinstance(raw_score, (int, float)):
            output: dict[str, object] = {"score": float(raw_score)}
        elif isinstance(raw_score, Mapping):
            output = dict(raw_score)
        else:
            output = {"score": 0.0}
        return ExecutionResult(
            output=output,
            success=True,
            tool_results=[],
            model_response=None,
            metadata={"delegate_type": "callable_evaluator"},
        )


def _parse_json_context(prompt: str) -> dict[str, object]:
    """Parse delegate prompt payload into mapping context.

    Args:
        prompt: Serialized JSON payload or raw prompt text.

    Returns:
        Mapping context for callable delegate invocation.
    """
    try:
        parsed = json.loads(prompt)
    except json.JSONDecodeError:
        # Non-JSON prompts are wrapped as a plain task string.
        return {"task": prompt}
    if isinstance(parsed, Mapping):
        return {str(key): _json_ready(value) for key, value in parsed.items()}
    return {"task": prompt}


def _is_workflow_delegate(delegate: object) -> TypeGuard[DelegateTarget]:
    """Return whether delegate appears to implement workflow-delegate contract.

    Args:
        delegate: Delegate candidate object.

    Returns:
        ``True`` when ``delegate`` exposes a callable ``run`` method.
    """
    run_callable = getattr(delegate, "run", None)
    return callable(run_callable)


def _is_agent_like(delegate: object) -> bool:
    """Return whether delegate appears to implement an agent-like ``run`` method.

    NOTE(review): alias of ``_is_workflow_delegate``; unused in this module.

    Args:
        delegate: Delegate candidate object.

    Returns:
        ``True`` when the delegate has a callable ``run`` attribute.
    """
    return _is_workflow_delegate(delegate)


def _extract_candidate_list(output: Mapping[str, object]) -> list[GeneratorValue]:
    """Extract candidate list from delegate output payload.

    Args:
        output: Delegate output mapping.

    Returns:
        Normalized candidate list.
    """
    # Preferred shape: explicit "candidates" list.
    candidates = output.get("candidates")
    if isinstance(candidates, list):
        return _coerce_generator_values(candidates)
    # Next: a single "candidate" value.
    candidate = output.get("candidate")
    normalized_candidate = _normalize_generator_value(candidate)
    if normalized_candidate is not None:
        return [normalized_candidate]
    # Last resort: parse JSON out of the raw model text.
    model_text = output.get("model_text")
    if isinstance(model_text, str):
        try:
            parsed = json.loads(model_text)
        except json.JSONDecodeError:
            return []
        if isinstance(parsed, list):
            return _coerce_generator_values(parsed)
        if isinstance(parsed, Mapping):
            parsed_candidates = parsed.get("candidates")
            if isinstance(parsed_candidates, list):
                return _coerce_generator_values(parsed_candidates)
            parsed_candidate = parsed.get("candidate")
            normalized_parsed_candidate = _normalize_generator_value(parsed_candidate)
            if normalized_parsed_candidate is not None:
                return [normalized_parsed_candidate]
    return []


def _extract_score(output: Mapping[str, object]) -> float:
    """Extract numeric score from delegate output payload.

    Args:
        output: Delegate output mapping.

    Returns:
        Extracted numeric score.
    """
    score = output.get("score")
    if isinstance(score, (int, float)):
        return float(score)
    # Fall back to a JSON "score" field embedded in the raw model text.
    model_text = output.get("model_text")
    if isinstance(model_text, str):
        try:
            parsed = json.loads(model_text)
        except json.JSONDecodeError:
            return 0.0
        if isinstance(parsed, Mapping):
            parsed_score = parsed.get("score")
            if isinstance(parsed_score, (int, float)):
                return float(parsed_score)
    return 0.0


def _normalize_candidate(
    candidate: Mapping[str, object] | str | int | float,
) -> dict[str, object]:
    """Normalize one candidate payload to a mapping.

    Args:
        candidate: Candidate payload.

    Returns:
        Mapping-form candidate payload.
    """
    if isinstance(candidate, Mapping):
        return {str(key): _json_ready(value) for key, value in candidate.items()}
    # Scalars are wrapped under a "value" key.
    return {"value": _json_ready(candidate)}


def _normalize_generator_value(value: object) -> GeneratorValue | None:
    """Normalize raw candidate value to supported generator union.

    Args:
        value: Raw candidate value.

    Returns:
        Normalized value when supported, otherwise ``None``.
    """
    if isinstance(value, Mapping):
        return {str(key): _json_ready(item) for key, item in value.items()}
    if isinstance(value, (str, int, float)):
        return value
    return None


def _coerce_generator_values(values: Sequence[object]) -> list[GeneratorValue]:
    """Coerce heterogeneous values to supported generator union values.

    Args:
        values: Raw candidate values.

    Returns:
        Supported candidate values only.
    """
    normalized: list[GeneratorValue] = []
    for value in values:
        normalized_value = _normalize_generator_value(value)
        # Unsupported values are silently dropped.
        if normalized_value is not None:
            normalized.append(normalized_value)
    return normalized


def _safe_float(value: object) -> float:
    """Convert values to float with deterministic fallback to zero.

    Args:
        value: Raw input value.

    Returns:
        Float representation or ``0.0`` fallback.
    """
    # bool is checked first because bool is a subclass of int.
    if isinstance(value, bool):
        return float(int(value))
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value.strip())
        except ValueError:
            return 0.0
    return 0.0


def _safe_int(value: object) -> int:
    """Convert values to int with deterministic fallback to zero.

    Args:
        value: Raw input value.

    Returns:
        Integer representation or ``0`` fallback.
    """
    # bool is checked first because bool is a subclass of int.
    if isinstance(value, bool):
        return int(value)
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # Truncates toward zero.
        return int(value)
    if isinstance(value, str):
        try:
            return int(value.strip())
        except ValueError:
            return 0
    return 0


def _json_ready(value: object) -> object:
    """Recursively convert values into JSON-safe shapes.

    Args:
        value: Raw input value.

    Returns:
        JSON-safe representation.
    """
    if isinstance(value, Mapping):
        return {str(key): _json_ready(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_json_ready(item) for item in value]
    if isinstance(value, tuple):
        # Tuples are emitted as JSON arrays.
        return [_json_ready(item) for item in value]
    if isinstance(value, (str, int, float, bool)) or value is None:
        return value
    # Anything else is stringified as a last resort.
    return str(value)


__all__ = ["BeamSearchPattern"]