"""Reusable ``propose_critic`` orchestration chunk."""
from __future__ import annotations
from collections.abc import Mapping
from design_research_agents._contracts._delegate import Delegate, ExecutionResult
from design_research_agents._contracts._llm import LLMClient
from design_research_agents._contracts._tools import ToolRuntime
from design_research_agents._contracts._workflow import (
DelegateStep,
DelegateTarget,
LogicStep,
LoopStep,
ModelStep,
WorkflowStep,
)
from design_research_agents._implementations._agents._direct_llm_call import (
DirectLLMCall,
)
from design_research_agents._implementations._shared._agent_internal._input_parsing import (
extract_prompt as _extract_prompt,
)
from design_research_agents._implementations._shared._agent_internal._model_resolution import (
resolve_agent_model,
)
from design_research_agents._implementations._shared._agent_internal._run_options import (
normalize_input_payload,
)
from design_research_agents._implementations._shared._workflow_internal._propose_critic_helpers import (
DEFAULT_CRITIC_SYSTEM_PROMPT,
DEFAULT_CRITIC_USER_PROMPT_TEMPLATE,
DEFAULT_PROPOSER_SYSTEM_PROMPT,
DEFAULT_PROPOSER_USER_PROMPT_TEMPLATE,
ProposeCriticLoopCallbacks,
)
from design_research_agents._runtime._patterns import (
MODE_PROPOSE_CRITIC,
WorkflowBudgetTracker,
attach_runtime_metadata,
build_compiled_pattern_execution,
build_pattern_execution_result,
normalize_mapping,
normalize_mapping_records,
normalize_request_id_prefix,
resolve_pattern_run_context,
resolve_prompt_override,
)
from design_research_agents._tracing import Tracer
from design_research_agents.workflow import CompiledExecution
from design_research_agents.workflow.workflow import Workflow
class ProposeCriticPattern(Delegate):
    """Propose/critique revision pattern built on workflow primitives.

    Each run compiles a single :class:`LoopStep` workflow in which a proposer
    drafts a response, a critic reviews it, and a logic step folds the
    critique back into the loop state until approval, an unrecoverable
    failure, or ``max_iterations`` is reached.
    """

    def __init__(
        self,
        *,
        llm_client: LLMClient,
        tool_runtime: ToolRuntime,
        proposer_delegate: DelegateTarget | None = None,
        critic_delegate: DelegateTarget | None = None,
        max_iterations: int = 3,
        proposer_system_prompt: str | None = None,
        proposer_user_prompt_template: str | None = None,
        critic_system_prompt: str | None = None,
        critic_user_prompt_template: str | None = None,
        default_request_id_prefix: str | None = None,
        default_dependencies: Mapping[str, object] | None = None,
        tracer: Tracer | None = None,
    ) -> None:
        """Store dependencies and initialize workflow-native orchestration settings.

        Args:
            llm_client: LLM client used by proposer and critic calls.
            tool_runtime: Tool runtime used by loop execution runtime.
            proposer_delegate: Optional proposer delegate override.
            critic_delegate: Optional critic delegate override.
            max_iterations: Maximum propose/critic iterations per run.
            proposer_system_prompt: Optional override for proposer system prompt.
            proposer_user_prompt_template: Optional proposer user prompt template.
            critic_system_prompt: Optional override for critic system prompt.
            critic_user_prompt_template: Optional critic user prompt template.
            default_request_id_prefix: Optional prefix used to derive request ids.
            default_dependencies: Dependency defaults merged into each run.
            tracer: Optional tracer used for run-level instrumentation.

        Raises:
            ValueError: If ``max_iterations`` is invalid.
        """
        if max_iterations < 1:
            raise ValueError("max_iterations must be >= 1.")
        self._llm_client = llm_client
        self._tool_runtime = tool_runtime
        self._proposer_delegate = proposer_delegate
        self._critic_delegate = critic_delegate
        self._max_iterations = max_iterations
        self._tracer = tracer
        # Most recently built workflow; exposed for inspection/debugging.
        self.workflow: Workflow | None = None
        self._default_request_id_prefix = normalize_request_id_prefix(default_request_id_prefix)
        self._default_dependencies = dict(default_dependencies or {})
        self._proposer_system_prompt = resolve_prompt_override(
            override=proposer_system_prompt,
            default_value=DEFAULT_PROPOSER_SYSTEM_PROMPT,
            field_name="proposer_system_prompt",
        )
        self._proposer_user_prompt_template = resolve_prompt_override(
            override=proposer_user_prompt_template,
            default_value=DEFAULT_PROPOSER_USER_PROMPT_TEMPLATE,
            field_name="proposer_user_prompt_template",
        )
        self._critic_system_prompt = resolve_prompt_override(
            override=critic_system_prompt,
            default_value=DEFAULT_CRITIC_SYSTEM_PROMPT,
            field_name="critic_system_prompt",
        )
        self._critic_user_prompt_template = resolve_prompt_override(
            override=critic_user_prompt_template,
            default_value=DEFAULT_CRITIC_USER_PROMPT_TEMPLATE,
            field_name="critic_user_prompt_template",
        )
        # Per-run state (callbacks + budget tracker) populated by _build_workflow.
        self._propose_critic_runtime: dict[str, object] | None = None

    def run(
        self,
        prompt: str,
        *,
        request_id: str | None = None,
        dependencies: Mapping[str, object] | None = None,
    ) -> ExecutionResult:
        """Execute one propose-and-critique orchestration run."""
        return self.compile(
            prompt=prompt,
            request_id=request_id,
            dependencies=dependencies,
        ).run()

    def compile(
        self,
        prompt: str,
        *,
        request_id: str | None = None,
        dependencies: Mapping[str, object] | None = None,
    ) -> CompiledExecution:
        """Compile one propose/critic workflow.

        Args:
            prompt: Task prompt to iteratively refine.
            request_id: Optional request id override for this run.
            dependencies: Optional dependency overrides merged with defaults.

        Returns:
            Deferred execution wrapping the compiled workflow.

        Raises:
            RuntimeError: If runtime state was not populated during build.
        """
        run_context = resolve_pattern_run_context(
            default_request_id_prefix=self._default_request_id_prefix,
            default_dependencies=self._default_dependencies,
            request_id=request_id,
            dependencies=dependencies,
        )
        normalized_input = normalize_input_payload(prompt)
        resolved_prompt = _extract_prompt(normalized_input)
        workflow = self._build_workflow(
            resolved_prompt,
            request_id=run_context.request_id,
            dependencies=run_context.dependencies,
        )
        # _build_workflow just stored callbacks + tracker; validate and bind them.
        callbacks, budget_tracker = self._require_runtime_state()
        return build_compiled_pattern_execution(
            workflow=workflow,
            pattern_name="ProposeCriticPattern",
            request_id=run_context.request_id,
            dependencies=run_context.dependencies,
            tracer=self._tracer,
            input_payload={"prompt": resolved_prompt, "mode": MODE_PROPOSE_CRITIC},
            workflow_request_id=f"{run_context.request_id}:propose_critic_loop",
            finalize=lambda workflow_result: _finalize_propose_critic_result(
                workflow_result=workflow_result,
                callbacks=callbacks,
                budget_tracker=budget_tracker,
                request_id=run_context.request_id,
                dependencies=run_context.dependencies,
                max_iterations=self._max_iterations,
            ),
        )

    def _require_runtime_state(self) -> tuple[ProposeCriticLoopCallbacks, WorkflowBudgetTracker]:
        """Return the callbacks and budget tracker stored by ``_build_workflow``.

        Raises:
            RuntimeError: If the runtime state is missing or has wrong types,
                i.e. ``_build_workflow`` has not run for this execution yet.
        """
        runtime = self._propose_critic_runtime or {}
        callbacks = runtime.get("callbacks")
        budget_tracker = runtime.get("budget_tracker")
        if not isinstance(callbacks, ProposeCriticLoopCallbacks) or not isinstance(
            budget_tracker, WorkflowBudgetTracker
        ):
            raise RuntimeError("Propose critic runtime state is unavailable before workflow execution.")
        return callbacks, budget_tracker

    def _build_workflow(
        self,
        prompt: str,
        *,
        request_id: str,
        dependencies: Mapping[str, object],
    ) -> Workflow:
        """Build the propose/critic workflow for one resolved run context.

        Side effect: stores the loop callbacks and budget tracker in
        ``self._propose_critic_runtime`` and the workflow in ``self.workflow``.
        """
        budget_tracker = WorkflowBudgetTracker()
        resolved_model = resolve_agent_model(llm_client=self._llm_client)
        proposer = self._proposer_delegate
        if proposer is None:
            # Default proposer: direct LLM call with the configured system prompt.
            proposer = DirectLLMCall(
                llm_client=self._llm_client,
                system_prompt=self._proposer_system_prompt,
                tracer=self._tracer,
            )
        callbacks = ProposeCriticLoopCallbacks(
            resolved_model=resolved_model,
            task_prompt=prompt,
            request_id=request_id,
            dependencies=dependencies,
            proposer_user_prompt_template=self._proposer_user_prompt_template,
            critic_system_prompt=self._critic_system_prompt,
            critic_user_prompt_template=self._critic_user_prompt_template,
            budget_tracker=budget_tracker,
        )
        loop_steps: tuple[WorkflowStep, ...]
        if self._critic_delegate is None:
            # No critic delegate: critique via a raw model call + parser.
            loop_steps = (
                DelegateStep(
                    step_id="propose_critic_proposer",
                    delegate=proposer,
                    prompt_builder=callbacks.build_proposer_prompt,
                ),
                ModelStep(
                    step_id="propose_critic_critic_model",
                    dependencies=("propose_critic_proposer",),
                    llm_client=self._llm_client,
                    request_builder=callbacks.build_critic_request,
                    response_parser=callbacks.parse_critic_model_response,
                ),
                LogicStep(
                    step_id="propose_critic_iteration",
                    dependencies=(
                        "propose_critic_proposer",
                        "propose_critic_critic_model",
                    ),
                    handler=callbacks.build_iteration_from_model,
                ),
            )
        else:
            # Critic delegate supplied: route critique through it instead.
            loop_steps = (
                DelegateStep(
                    step_id="propose_critic_proposer",
                    delegate=proposer,
                    prompt_builder=callbacks.build_proposer_prompt,
                ),
                DelegateStep(
                    step_id="propose_critic_critic_delegate",
                    dependencies=("propose_critic_proposer",),
                    delegate=self._critic_delegate,
                    prompt_builder=callbacks.build_critic_prompt,
                ),
                LogicStep(
                    step_id="propose_critic_iteration",
                    dependencies=(
                        "propose_critic_proposer",
                        "propose_critic_critic_delegate",
                    ),
                    handler=callbacks.build_iteration_from_delegate,
                ),
            )
        workflow = Workflow(
            tool_runtime=self._tool_runtime,
            tracer=self._tracer,
            input_schema={"type": "object"},
            base_context={"prompt": prompt},
            steps=[
                LoopStep(
                    step_id="propose_critic_loop",
                    steps=loop_steps,
                    max_iterations=self._max_iterations,
                    initial_state={
                        "proposal": "",
                        "approved": False,
                        "feedback": "",
                        "revision_goals": [],
                        "failure_reason": None,
                        "failure_error": None,
                        "critique_iterations": [],
                    },
                    continue_predicate=callbacks.continue_predicate,
                    state_reducer=callbacks.state_reducer,
                    execution_mode="sequential",
                    failure_policy="propagate_failed_state",
                )
            ],
        )
        self.workflow = workflow
        self._propose_critic_runtime = {
            "budget_tracker": budget_tracker,
            "callbacks": callbacks,
        }
        return workflow

    def _run_propose_critic(
        self,
        *,
        prompt: str,
        request_id: str,
        dependencies: Mapping[str, object],
    ) -> ExecutionResult:
        """Propose/critic loop until approval or termination.

        Args:
            prompt: Task prompt to iteratively refine.
            request_id: Resolved request id for this orchestration run.
            dependencies: Normalized dependency mapping for this run.

        Returns:
            Final propose/critic pattern result.

        Raises:
            RuntimeError: If loop execution fails irrecoverably.
        """
        workflow = self._build_workflow(
            prompt,
            request_id=request_id,
            dependencies=dependencies,
        )
        callbacks, budget_tracker = self._require_runtime_state()
        workflow_result = workflow.run(
            input={},
            execution_mode="sequential",
            failure_policy="skip_dependents",
            request_id=f"{request_id}:propose_critic_loop",
            dependencies=dependencies,
        )
        return _finalize_propose_critic_result(
            workflow_result=workflow_result,
            callbacks=callbacks,
            budget_tracker=budget_tracker,
            request_id=request_id,
            dependencies=dependencies,
            max_iterations=self._max_iterations,
        )
def _run_propose_critic(
self,
*,
prompt: str,
request_id: str,
dependencies: Mapping[str, object],
) -> ExecutionResult:
"""Propose/critic loop until approval or termination.
Args:
prompt: Task prompt to iteratively refine.
request_id: Resolved request id for this orchestration run.
dependencies: Normalized dependency mapping for this run.
Returns:
Final reflexion pattern result.
Raises:
RuntimeError: If loop execution fails irrecoverably.
"""
workflow = self._build_workflow(
prompt,
request_id=request_id,
dependencies=dependencies,
)
runtime = self._propose_critic_runtime or {}
callbacks = runtime.get("callbacks")
budget_tracker = runtime.get("budget_tracker")
if not isinstance(callbacks, ProposeCriticLoopCallbacks) or not isinstance(
budget_tracker, WorkflowBudgetTracker
):
raise RuntimeError("Propose critic runtime state is unavailable before workflow execution.")
workflow_result = workflow.run(
input={},
execution_mode="sequential",
failure_policy="skip_dependents",
request_id=f"{request_id}:propose_critic_loop",
dependencies=dependencies,
)
return _finalize_propose_critic_result(
workflow_result=workflow_result,
callbacks=callbacks,
budget_tracker=budget_tracker,
request_id=request_id,
dependencies=dependencies,
max_iterations=self._max_iterations,
)
def _finalize_propose_critic_result(
    *,
    workflow_result: ExecutionResult,
    callbacks: ProposeCriticLoopCallbacks,
    budget_tracker: WorkflowBudgetTracker,
    request_id: str,
    dependencies: Mapping[str, object],
    max_iterations: int,
) -> ExecutionResult:
    """Build final propose/critic result from a workflow execution.

    Args:
        workflow_result: Raw workflow execution result to summarize.
        callbacks: Loop callbacks holding the last model response.
        budget_tracker: Budget tracker whose metadata is attached to the result.
        request_id: Resolved request id for this run.
        dependencies: Normalized dependency mapping for this run.
        max_iterations: Configured iteration cap (fallback iteration count).

    Raises:
        RuntimeError: If the loop step result is missing or an iteration failed.
    """

    def _nonempty_str(value: object) -> str | None:
        # Collapse non-strings and empty strings to None.
        return str(value) if isinstance(value, str) and value else None

    loop_step = workflow_result.step_results.get("propose_critic_loop")
    if loop_step is None:
        raise RuntimeError("Propose and critique loop step result is missing.")
    loop_output = loop_step.output
    final_state = normalize_mapping(loop_output.get("final_state"))

    critique_iterations = normalize_mapping_records(final_state.get("critique_iterations"))
    proposal = str(final_state.get("proposal", ""))
    approved = bool(final_state.get("approved"))
    failure_reason = _nonempty_str(final_state.get("failure_reason"))
    failure_error = _nonempty_str(final_state.get("failure_error"))
    loop_terminated_reason = str(loop_output.get("terminated_reason", "max_iterations_reached"))

    # A failed iteration is unrecoverable: surface it as an exception.
    if "iteration_failed" in (loop_terminated_reason, failure_reason):
        raise RuntimeError(failure_error or "Workflow loop iteration failed.")

    if approved:
        terminated_reason = "approved"
    else:
        terminated_reason = failure_reason or "max_iterations_reached"

    iteration_count = len(critique_iterations)
    # Arguments shared by the success and critic-failure result builders.
    common_kwargs = dict(
        final_output={
            "proposal": proposal,
            "approved": approved,
            "iterations": iteration_count,
        },
        details={
            "proposal": proposal,
            "approved": approved,
            "critique_iterations": critique_iterations,
        },
        workflow_payload=workflow_result.to_dict(),
        artifacts=workflow_result.output.get("artifacts", []),
        request_id=request_id,
        dependencies=dependencies,
        mode=MODE_PROPOSE_CRITIC,
        tool_results=[],
        model_response=callbacks.last_model_response,
    )

    if failure_reason in {"critic_invalid_json", "critic_invalid_schema"}:
        # Critic produced unusable output: report a failure without loop metadata.
        pattern_result = build_pattern_execution_result(
            success=False,
            terminated_reason=failure_reason,
            metadata={"stage": "critic", "iterations": iteration_count},
            error=failure_error or "Critic iteration failed.",
            **common_kwargs,
        )
        extra_metadata = None
    else:
        pattern_result = build_pattern_execution_result(
            success=approved,
            terminated_reason=terminated_reason,
            metadata={"iterations": iteration_count},
            **common_kwargs,
        )
        extra_metadata = {
            "loop": {
                "iterations": loop_output.get("iterations", max_iterations),
                "iterations_executed": loop_output.get("iterations_executed", 0),
                "terminated_reason": loop_terminated_reason,
            }
        }

    return attach_runtime_metadata(
        agent_result=pattern_result,
        requested_mode=MODE_PROPOSE_CRITIC,
        resolved_mode=MODE_PROPOSE_CRITIC,
        budget_metadata=budget_tracker.as_metadata(),
        extra_metadata=extra_metadata,
    )
# Public API of this module: only the pattern class is exported.
__all__ = [
    "ProposeCriticPattern",
]