"""Unified multi-step agent facade with explicit mode selection."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Literal
from design_research_agents._contracts._delegate import Delegate, ExecutionResult
from design_research_agents._contracts._llm import LLMClient
from design_research_agents._contracts._memory import MemoryStore
from design_research_agents._contracts._tools import ToolRuntime
from design_research_agents._implementations._shared._agent_internal._multi_step_modes._code import (
MultiStepCodeToolCallingAgent as _CodeModeStrategy,
)
from design_research_agents._implementations._shared._agent_internal._multi_step_modes._direct import (
MultiStepDirectLLMAgent as _DirectModeStrategy,
)
from design_research_agents._implementations._shared._agent_internal._multi_step_modes._direct import (
_coerce_state_records,
_parse_controller_decision,
)
from design_research_agents._implementations._shared._agent_internal._multi_step_modes._json import (
MultiStepJsonToolCallingAgent as _JsonModeStrategy,
)
from design_research_agents._implementations._shared._agent_internal._prompt_alternatives import (
AlternativesPromptTarget,
)
from design_research_agents._tracing import Tracer
from design_research_agents.workflow import CompiledExecution
type MultiStepMode = Literal["direct", "json", "code"]
_FINAL_ANSWER_TOOL_NAME = "final_answer"
def _normalize_mode(raw_mode: object) -> MultiStepMode:
"""Normalize one mode input to the closed set of supported strategy names."""
normalized_mode = raw_mode.strip().lower() if isinstance(raw_mode, str) else ""
if normalized_mode == "direct":
return "direct"
if normalized_mode == "json":
return "json"
if normalized_mode == "code":
return "code"
raise ValueError("mode must be one of: 'direct', 'json', 'code'.")
class MultiStepAgent(Delegate):
    """Single multi-step runtime entrypoint for direct/json/code strategies.

    The constructor validates the configuration and builds exactly one
    mode-specific strategy; ``compile`` and ``run`` delegate to it.
    """

    def __init__(
        self,
        *,
        mode: MultiStepMode,
        llm_client: LLMClient,
        tool_runtime: ToolRuntime | None = None,
        max_steps: int = 5,
        stop_on_step_failure: bool = True,
        controller_system_prompt: str | None = None,
        controller_user_prompt_template: str | None = None,
        continuation_system_prompt: str | None = None,
        continuation_user_prompt_template: str | None = None,
        step_user_prompt_template: str | None = None,
        tool_calling_system_prompt: str | None = None,
        tool_calling_user_prompt_template: str | None = None,
        alternatives_prompt_target: AlternativesPromptTarget = "user",
        continuation_memory_tail_items: int = 6,
        step_memory_tail_items: int = 8,
        memory_store: MemoryStore | None = None,
        memory_namespace: str = "default",
        memory_read_top_k: int = 4,
        memory_write_observations: bool = True,
        max_tool_calls_per_step: int = 5,
        execution_timeout_seconds: int = 5,
        validate_tool_input_schema: bool = False,
        normalize_generated_code_per_step: bool = False,
        default_tools_per_step: Sequence[Mapping[str, object]] | None = None,
        allowed_tools: Sequence[str] | None = None,
        tracer: Tracer | None = None,
    ) -> None:
        """Initialize one mode-specific multi-step strategy.

        Only the options relevant to the selected mode are forwarded to its
        strategy; options that belong to other modes are accepted and ignored.

        Args:
            mode: Required strategy mode (``direct``, ``json``, or ``code``),
                matched case-insensitively after trimming whitespace.
            llm_client: LLM client shared by all strategy modes.
            tool_runtime: Tool runtime required for ``json`` and ``code`` modes;
                ignored in ``direct`` mode.
            max_steps: Maximum number of multi-step iterations (all modes).
            stop_on_step_failure: Whether to stop loop execution on failed steps
                (``json``/``code`` modes).
            controller_system_prompt: Direct-mode controller system prompt override.
            controller_user_prompt_template: Direct-mode controller user prompt override.
            continuation_system_prompt: Continuation system prompt override
                (``json``/``code`` modes).
            continuation_user_prompt_template: Continuation user prompt override
                (``json``/``code`` modes).
            step_user_prompt_template: Step action user prompt override
                (``json``/``code`` modes).
            tool_calling_system_prompt: Json-mode tool-calling system prompt override.
            tool_calling_user_prompt_template: Json-mode tool-calling user prompt override.
            alternatives_prompt_target: Prompt insertion target for alternatives
                blocks (``json``/``code`` modes).
            continuation_memory_tail_items: Continuation memory tail item count
                (``json``/``code`` modes).
            step_memory_tail_items: Step memory tail item count (all modes).
            memory_store: Optional persistent memory dependency (``json``/``code`` modes).
            memory_namespace: Memory namespace for read/write operations
                (``json``/``code`` modes).
            memory_read_top_k: Memory retrieval top-k (``json``/``code`` modes).
            memory_write_observations: Whether to persist per-step observations
                (``json``/``code`` modes).
            max_tool_calls_per_step: Code-mode per-step tool call cap.
            execution_timeout_seconds: Code-mode sandbox timeout.
            validate_tool_input_schema: Code-mode tool input schema validation toggle.
            normalize_generated_code_per_step: Code-mode code normalization toggle.
            default_tools_per_step: Code-mode default tool allowlist.
            allowed_tools: Optional json-mode tool allowlist.
            tracer: Optional tracer dependency (all modes).

        Raises:
            ValueError: If ``mode`` is unsupported, or if a ``json``/``code``
                mode is given a missing tool runtime, one that exposes no tools,
                or one that exposes the reserved ``final_answer`` tool name.
        """
        normalized_mode = _normalize_mode(mode)
        # Validate before any attribute is set so a bad configuration fails fast.
        checked_runtime: ToolRuntime | None = None
        if normalized_mode != "direct":
            checked_runtime = self._validated_tool_runtime(tool_runtime)
        self._mode = normalized_mode
        self._strategy: Delegate
        if checked_runtime is None:
            # Only direct mode skips runtime validation. It never receives the
            # tool runtime, nor the continuation/memory/alternatives options.
            self._strategy = _DirectModeStrategy(
                llm_client=llm_client,
                max_steps=max_steps,
                controller_system_prompt=controller_system_prompt,
                controller_user_prompt_template=controller_user_prompt_template,
                step_memory_tail_items=step_memory_tail_items,
                tracer=tracer,
            )
            return
        if normalized_mode == "code":
            # Code mode ignores the controller prompts, the tool-calling
            # prompts, and ``allowed_tools``.
            self._strategy = _CodeModeStrategy(
                llm_client=llm_client,
                tool_runtime=checked_runtime,
                max_steps=max_steps,
                max_tool_calls_per_step=max_tool_calls_per_step,
                execution_timeout_seconds=execution_timeout_seconds,
                validate_tool_input_schema=validate_tool_input_schema,
                normalize_generated_code_per_step=normalize_generated_code_per_step,
                stop_on_step_failure=stop_on_step_failure,
                default_tools_per_step=default_tools_per_step,
                continuation_system_prompt=continuation_system_prompt,
                continuation_user_prompt_template=continuation_user_prompt_template,
                step_user_prompt_template=step_user_prompt_template,
                alternatives_prompt_target=alternatives_prompt_target,
                continuation_memory_tail_items=continuation_memory_tail_items,
                step_memory_tail_items=step_memory_tail_items,
                memory_store=memory_store,
                memory_namespace=memory_namespace,
                memory_read_top_k=memory_read_top_k,
                memory_write_observations=memory_write_observations,
                tracer=tracer,
            )
            return
        # ``_normalize_mode`` admits only three modes, so this is ``json``. It
        # ignores the controller prompts and all code-mode-only options.
        self._strategy = _JsonModeStrategy(
            llm_client=llm_client,
            tool_runtime=checked_runtime,
            max_steps=max_steps,
            stop_on_step_failure=stop_on_step_failure,
            continuation_system_prompt=continuation_system_prompt,
            continuation_user_prompt_template=continuation_user_prompt_template,
            step_user_prompt_template=step_user_prompt_template,
            tool_calling_system_prompt=tool_calling_system_prompt,
            tool_calling_user_prompt_template=tool_calling_user_prompt_template,
            alternatives_prompt_target=alternatives_prompt_target,
            continuation_memory_tail_items=continuation_memory_tail_items,
            step_memory_tail_items=step_memory_tail_items,
            memory_store=memory_store,
            memory_namespace=memory_namespace,
            memory_read_top_k=memory_read_top_k,
            memory_write_observations=memory_write_observations,
            allowed_tools=allowed_tools,
            tracer=tracer,
        )

    @staticmethod
    def _validated_tool_runtime(tool_runtime: ToolRuntime | None) -> ToolRuntime:
        """Check that ``tool_runtime`` can back a ``json`` or ``code`` strategy.

        Args:
            tool_runtime: Runtime passed to the constructor.

        Returns:
            The same runtime object, narrowed to a non-``None`` type.

        Raises:
            ValueError: If the runtime is missing, lists no tools, or lists a
                tool under the reserved ``final_answer`` name.
        """
        if tool_runtime is None:
            raise ValueError("tool_runtime is required when mode is 'json' or 'code'.")
        runtime_tools = tuple(tool_runtime.list_tools())
        if not runtime_tools:
            raise ValueError("tool_runtime must expose at least one tool when mode is 'json' or 'code'.")
        if any(tool.name == _FINAL_ANSWER_TOOL_NAME for tool in runtime_tools):
            raise ValueError(f"tool_runtime cannot expose reserved tool name '{_FINAL_ANSWER_TOOL_NAME}'.")
        return tool_runtime

    @property
    def workflow(self) -> object | None:
        """Return the selected strategy's ``workflow`` attribute, or ``None``.

        Per the original intent this is the most recently compiled workflow;
        the attribute itself is owned by the strategy, not by this facade.
        """
        return getattr(self._strategy, "workflow", None)

    def compile(
        self,
        prompt: str,
        *,
        request_id: str | None = None,
        dependencies: Mapping[str, object] | None = None,
    ) -> CompiledExecution:
        """Compile one run through the selected strategy mode.

        Args:
            prompt: Prompt for the run.
            request_id: Optional request identifier, forwarded unchanged.
            dependencies: Optional dependency mapping, forwarded unchanged.

        Returns:
            The ``CompiledExecution`` produced by the strategy.

        Raises:
            TypeError: If the strategy has no callable ``compile`` or its
                ``compile`` returns something other than ``CompiledExecution``.
        """
        # The strategy is held as a plain ``Delegate``, so ``compile`` is
        # resolved dynamically and its result type-checked before returning.
        compile_callable = getattr(self._strategy, "compile", None)
        if not callable(compile_callable):
            raise TypeError("Selected multi-step strategy does not implement compile().")
        compiled_execution = compile_callable(
            prompt,
            request_id=request_id,
            dependencies=dependencies,
        )
        if not isinstance(compiled_execution, CompiledExecution):
            raise TypeError("Selected multi-step strategy compile() must return CompiledExecution.")
        return compiled_execution

    def run(
        self,
        prompt: str,
        *,
        request_id: str | None = None,
        dependencies: Mapping[str, object] | None = None,
    ) -> ExecutionResult:
        """Execute one run through the selected strategy mode.

        Equivalent to ``self.compile(...).run()``.

        Args:
            prompt: Prompt for the run.
            request_id: Optional request identifier, forwarded unchanged.
            dependencies: Optional dependency mapping, forwarded unchanged.

        Returns:
            The ``ExecutionResult`` of running the compiled execution.

        Raises:
            TypeError: Propagated from ``compile`` when the strategy is invalid.
        """
        return self.compile(
            prompt,
            request_id=request_id,
            dependencies=dependencies,
        ).run()
# Public API of this module.
__all__ = [
    "MultiStepAgent",
    "MultiStepMode",
]
# Private helpers re-exported so existing test and transition-era imports of
# these names from this module keep working.
__all__ += [
    "_coerce_state_records",
    "_parse_controller_decision",
]