From c5d74b9b3e18b1abf07d369346b5f7e68797ce51 Mon Sep 17 00:00:00 2001 From: Friel Date: Sat, 27 Jun 2026 05:55:47 +0000 Subject: [PATCH 1/2] Preserve late steer after turn finalization --- codex-rs/core/src/session/input_queue.rs | 4 + codex-rs/core/src/session/tests.rs | 149 +++++++++++++++++++++++ codex-rs/core/src/tasks/mod.rs | 141 ++++++++++++++++----- codex-rs/core/src/tasks/regular.rs | 30 +++++ 4 files changed, 295 insertions(+), 29 deletions(-) diff --git a/codex-rs/core/src/session/input_queue.rs b/codex-rs/core/src/session/input_queue.rs index e46a71892be5..863a8ba26b14 100644 --- a/codex-rs/core/src/session/input_queue.rs +++ b/codex-rs/core/src/session/input_queue.rs @@ -253,6 +253,10 @@ impl InputQueue { } impl TurnInputQueue { + pub(crate) fn is_empty(&self) -> bool { + self.items.is_empty() + } + fn has_user_input(&self) -> bool { self.items .iter() diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index d6336acd57a8..82fb9f8b7c1d 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -9213,6 +9213,61 @@ impl SessionTask for CompletingTask { } } +struct PendingInputContinuationTask { + final_pending_input_check_reached: Arc, + allow_initial_run_to_finish: Arc, +} + +impl SessionTask for PendingInputContinuationTask { + fn kind(&self) -> TaskKind { + TaskKind::Regular + } + + fn span_name(&self) -> &'static str { + "session_task.pending_input_continuation" + } + + async fn run( + self: Arc, + session: Arc, + _ctx: Arc, + _input: Vec, + _cancellation_token: CancellationToken, + ) -> SessionTaskResult { + let session = session.clone_session(); + assert!( + !session + .input_queue + .has_pending_input(&session.active_turn) + .await + ); + self.final_pending_input_check_reached.notify_one(); + self.allow_initial_run_to_finish.notified().await; + Ok(None) + } + + fn supports_pending_input_continuation(&self) -> bool { + true + } + + async fn run_pending_input_continuation( + self: Arc, + session: Arc, + ctx: Arc, + cancellation_token: CancellationToken, + ) -> SessionTaskResult { + crate::session::turn::run_turn( + session.clone_session(), + ctx, + session.turn_extension_data(), + Vec::new(), + /*prewarmed_client_session*/ None, + cancellation_token, + ) + .await + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum TerminalEventKind { TurnComplete, @@ -9733,6 +9788,100 @@ async fn task_finish_emits_turn_item_lifecycle_for_leftover_pending_user_input() )); } +#[test] +fn task_finish_continues_input_accepted_after_final_pending_input_check() { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + // Match the production runtime because this test executes the full sampling future. + .thread_stack_size(16 * 1024 * 1024) + .enable_all() + .build() + .expect("build test runtime"); + runtime.block_on(task_finish_continues_late_input()); +} + +async fn task_finish_continues_late_input() { + let server = start_mock_server().await; + let response_mock = mount_sse_once( + &server, + sse(vec![ + ev_response_created("response-1"), + ev_completed("response-1"), + ]), + ) + .await; + let base_url = server.uri(); + let (session, turn_context, rx) = make_session_and_context_with_auth_and_config_and_rx( + CodexAuth::from_api_key("Test API Key"), + Vec::new(), + move |config| config.model_provider.base_url = Some(base_url), + ) + .await; + let final_pending_input_check_reached = Arc::new(tokio::sync::Notify::new()); + let allow_initial_run_to_finish = Arc::new(tokio::sync::Notify::new()); + + session + .spawn_task( + Arc::clone(&turn_context), + Vec::new(), + PendingInputContinuationTask { + final_pending_input_check_reached: Arc::clone(&final_pending_input_check_reached), + allow_initial_run_to_finish: Arc::clone(&allow_initial_run_to_finish), + }, + ) + .await; + timeout( + StdDuration::from_secs(2), + final_pending_input_check_reached.notified(), + ) + .await + .expect("task should reach its final pending-input check"); + + let client_id = "late-steer-client-id"; + session + .steer_input( + vec![UserInput::Text { + text: "late steer".to_string(), + text_elements: Vec::new(), + }], + /*additional_context*/ Default::default(), + Some(&turn_context.sub_id), + Some(client_id.to_string()), + /*responsesapi_client_metadata*/ None, + ) + .await + .expect("steer should be accepted while the task is still active"); + allow_initial_run_to_finish.notify_one(); + + let mut user_message_client_ids = Vec::new(); + timeout(StdDuration::from_secs(15), async { + loop { + let event = rx.recv().await.expect("event channel should remain open"); + match event.msg { + EventMsg::UserMessage(event) => { + user_message_client_ids.push(event.client_id); + } + EventMsg::TurnComplete(_) => break, + _ => {} + } + } + }) + .await + .expect("continued task should complete"); + + let request = response_mock.single_request(); + assert_eq!( + request + .message_input_texts("user") + .into_iter() + .filter(|text| text == "late steer") + .count(), + 1 + ); + assert_eq!(user_message_client_ids, vec![Some(client_id.to_string())]); + assert!(session.active_turn.lock().await.is_none()); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn task_finish_emits_thread_idle_lifecycle_after_active_turn_clears() { struct ThreadIdleRecorder { diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index c6d9d2990c2a..4ac5a5383840 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -69,6 +69,13 @@ const TASK_COMPACT_METRIC: &str = "codex.task.compact"; pub(crate) type SessionTaskResult = CodexResult>; +/// Whether task finalization completed or reserved the active task for pending input. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum TaskFinishAction { + Finish, + Continue, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum InterruptedTurnHistoryMarker { Disabled, @@ -237,6 +244,24 @@ pub(crate) trait SessionTask: Send + Sync + 'static { cancellation_token: CancellationToken, ) -> impl std::future::Future + Send; + /// Whether this task can consume input accepted while its initial run was finishing. + fn supports_pending_input_continuation(&self) -> bool { + false + } + + /// Continues the task without repeating its turn-start lifecycle. + fn run_pending_input_continuation( + self: Arc, + session: Arc, + ctx: Arc, + cancellation_token: CancellationToken, + ) -> impl std::future::Future + Send { + async move { + let _ = (self, session, ctx, cancellation_token); + unreachable!("task does not support pending input continuation") + } + } + /// Gives the task a chance to perform cleanup after an abort. /// /// The default implementation is a no-op; override this if additional @@ -266,6 +291,15 @@ pub(crate) trait AnySessionTask: Send + Sync + 'static { cancellation_token: CancellationToken, ) -> BoxFuture<'static, SessionTaskResult>; + fn supports_pending_input_continuation(&self) -> bool; + + fn run_pending_input_continuation( + self: Arc, + session: Arc, + ctx: Arc, + cancellation_token: CancellationToken, + ) -> BoxFuture<'static, SessionTaskResult>; + fn abort<'a>( &'a self, session: Arc, @@ -301,6 +335,24 @@ where )) } + fn supports_pending_input_continuation(&self) -> bool { + ::supports_pending_input_continuation(self) + } + + fn run_pending_input_continuation( + self: Arc, + session: Arc, + ctx: Arc, + cancellation_token: CancellationToken, + ) -> BoxFuture<'static, SessionTaskResult> { + Box::pin(::run_pending_input_continuation( + self, + session, + ctx, + cancellation_token, + )) + } + fn abort<'a>( &'a self, session: Arc, @@ -401,7 +453,7 @@ impl Session { let handle = tokio::spawn( async move { let ctx_for_finish = Arc::clone(&ctx); - let task_result = task_for_run + let mut task_result = Arc::clone(&task_for_run) .run( Arc::clone(&session_ctx), ctx, @@ -411,22 +463,39 @@ impl Session { .instrument(trace_span!("session_task.run")) .await; let sess = session_ctx.clone_session(); - if let Err(err) = sess.flush_rollout().await { - warn!("failed to flush rollout before completing turn: {err}"); - sess.send_event( - ctx_for_finish.as_ref(), - EventMsg::Warning(WarningEvent { - message: format!( - "Failed to save the conversation transcript; Codex will continue retrying. Error: {err}" - ), - }), - ) - .await; - } - if !task_cancellation_token.is_cancelled() { - // Finish uniformly from the spawn site so all tasks share the same lifecycle. - sess.on_task_finished(Arc::clone(&ctx_for_finish), task_result) + loop { + if let Err(err) = sess.flush_rollout().await { + warn!("failed to flush rollout before completing turn: {err}"); + sess.send_event( + ctx_for_finish.as_ref(), + EventMsg::Warning(WarningEvent { + message: format!( + "Failed to save the conversation transcript; Codex will continue retrying. Error: {err}" + ), + }), + ) .await; + } + if task_cancellation_token.is_cancelled() { + break; + } + // Finish uniformly from the spawn site so all tasks share the same lifecycle. + match sess + .on_task_finished(Arc::clone(&ctx_for_finish), task_result) + .await + { + TaskFinishAction::Finish => break, + TaskFinishAction::Continue => { + task_result = Arc::clone(&task_for_run) + .run_pending_input_continuation( + Arc::clone(&session_ctx), + Arc::clone(&ctx_for_finish), + task_cancellation_token.child_token(), + ) + .instrument(trace_span!("session_task.run_pending_input_continuation")) + .await; + } + } } done_clone.notify_waiters(); } @@ -560,11 +629,16 @@ impl Session { true } + #[expect( + clippy::await_holding_invalid_type, + reason = "task removal and the final pending-input check must remain atomic with steering" + )] pub async fn on_task_finished( self: &Arc, turn_context: Arc, task_result: SessionTaskResult, - ) { + ) -> TaskFinishAction { + let can_continue = task_result.is_ok(); let (last_agent_message, abort_reason) = match task_result { Ok(last_agent_message) => (last_agent_message, None), Err(CodexErr::TurnAborted) => (None, Some(TurnAbortReason::Interrupted)), @@ -573,21 +647,29 @@ impl Session { (None, None) } }; - turn_context - .turn_metadata_state - .cancel_git_enrichment_task(); - let turn_state = { let mut active = self.active_turn.lock().await; - active.as_mut().and_then(|active_turn| { - let task = active_turn.task.take()?; - task.handle.detach(); - Some(Arc::clone(&active_turn.turn_state)) - }) - }; - let Some(turn_state) = turn_state else { - return; + let Some(active_turn) = active.as_mut() else { + return TaskFinishAction::Finish; + }; + let Some(task) = active_turn.task.as_ref() else { + return TaskFinishAction::Finish; + }; + if can_continue && task.task.supports_pending_input_continuation() { + let turn_state = active_turn.turn_state.lock().await; + if !turn_state.pending_input.is_empty() { + return TaskFinishAction::Continue; + } + } + let Some(task) = active_turn.task.take() else { + return TaskFinishAction::Finish; + }; + task.handle.detach(); + Arc::clone(&active_turn.turn_state) }; + turn_context + .turn_metadata_state + .cancel_git_enrichment_task(); let pending_input = self .input_queue .take_pending_input_for_turn_state(turn_state.as_ref()) @@ -801,6 +883,7 @@ impl Session { if let Err(err) = self.flush_rollout().await { warn!("failed to flush rollout after emitting terminal turn event: {err}"); } + TaskFinishAction::Finish } async fn take_active_turn(&self) -> Option { diff --git a/codex-rs/core/src/tasks/regular.rs b/codex-rs/core/src/tasks/regular.rs index 40837728a202..277065a3fed7 100644 --- a/codex-rs/core/src/tasks/regular.rs +++ b/codex-rs/core/src/tasks/regular.rs @@ -87,4 +87,34 @@ impl SessionTask for RegularTask { next_input = Vec::new(); } } + + fn supports_pending_input_continuation(&self) -> bool { + true + } + + async fn run_pending_input_continuation( + self: Arc, + session: Arc, + ctx: Arc, + cancellation_token: CancellationToken, + ) -> SessionTaskResult { + let sess = session.clone_session(); + let turn_extension_data = session.turn_extension_data(); + let run_turn_span = trace_span!("run_turn"); + loop { + let last_agent_message = run_turn( + Arc::clone(&sess), + Arc::clone(&ctx), + Arc::clone(&turn_extension_data), + Vec::new(), + /*prewarmed_client_session*/ None, + cancellation_token.child_token(), + ) + .instrument(run_turn_span.clone()) + .await?; + if !sess.input_queue.has_pending_input(&sess.active_turn).await { + return Ok(last_agent_message); + } + } + } } From 875de1ec111beed474f590f0d329f7b9cf1835e3 Mon Sep 17 00:00:00 2001 From: Friel Date: Sat, 27 Jun 2026 06:56:16 +0000 Subject: [PATCH 2/2] Add late steer integration coverage --- codex-rs/core/tests/suite/pending_input.rs | 156 +++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/codex-rs/core/tests/suite/pending_input.rs b/codex-rs/core/tests/suite/pending_input.rs index 9c9852a27644..7f3fbdcb1a25 100644 --- a/codex-rs/core/tests/suite/pending_input.rs +++ b/codex-rs/core/tests/suite/pending_input.rs @@ -1,5 +1,9 @@ use core_test_support::test_codex::local_selections; use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::mpsc; +use std::time::Duration; use codex_core::CodexThread; use codex_core::config::CurrentTimeReminderConfig; @@ -39,6 +43,50 @@ use serde_json::Value; use serde_json::from_slice; use serde_json::json; use tokio::sync::oneshot; +use tracing::Subscriber; +use tracing::span::Id; +use tracing_subscriber::Layer; +use tracing_subscriber::layer::Context; +use tracing_subscriber::prelude::*; +use tracing_subscriber::registry::LookupSpan; + +/// Pauses task finalization after `RegularTask::run` has made its last pending-input decision. +/// +/// The test thread uses this span boundary to submit a steer through `CodexThread`, exercising the +/// production gap between the regular task returning and `Session::on_task_finished` atomically +/// deciding whether to continue or finish the turn. +struct RegularTaskRunCloseBarrier { + armed: Arc, + target_thread: std::thread::ThreadId, + run_closed: mpsc::SyncSender<()>, + steer_finished: std::sync::Mutex>, +} + +impl Layer for RegularTaskRunCloseBarrier +where + S: Subscriber + for<'lookup> LookupSpan<'lookup>, +{ + fn on_close(&self, id: Id, ctx: Context<'_, S>) { + let is_regular_task_run = ctx + .metadata(&id) + .is_some_and(|metadata| metadata.name() == "session_task.run"); + if std::thread::current().id() != self.target_thread + || !is_regular_task_run + || !self.armed.swap(false, Ordering::SeqCst) + { + return; + } + + self.run_closed + .send(()) + .expect("late-steer worker should wait for regular task completion"); + self.steer_finished + .lock() + .expect("late-steer barrier mutex should not be poisoned") + .recv_timeout(Duration::from_secs(15)) + .expect("late-steer worker should finish before task finalization resumes"); + } +} fn ev_message_item_done(id: &str, text: &str) -> Value { serde_json::json!({ @@ -287,6 +335,114 @@ async fn wait_for_sleep_item_completed(codex: &CodexThread, call_id: &str, durat ); } +#[tokio::test(flavor = "current_thread")] +async fn steer_accepted_after_regular_task_returns_continues_same_turn() { + const INITIAL_PROMPT: &str = "first prompt"; + const STEER_PROMPT: &str = "late steer"; + const CLIENT_ID: &str = "late-steer-client-id"; + const FINAL_MESSAGE: &str = "processed late steer"; + + let armed = Arc::new(AtomicBool::new(false)); + let (run_closed_tx, run_closed_rx) = mpsc::sync_channel(0); + let (steer_finished_tx, steer_finished_rx) = mpsc::sync_channel(0); + let subscriber = tracing_subscriber::registry().with(RegularTaskRunCloseBarrier { + armed: Arc::clone(&armed), + target_thread: std::thread::current().id(), + run_closed: run_closed_tx, + steer_finished: std::sync::Mutex::new(steer_finished_rx), + }); + // A global subscriber keeps callsite interest stable while other integration tests run in + // parallel. The layer only acts on this current-thread runtime and disarms after one span. + tracing::subscriber::set_global_default(subscriber) + .expect("pending-input integration test should install the global tracing subscriber"); + + let final_chunks = vec![ + chunk(ev_response_created("resp-2")), + chunk(ev_message_item_done("msg-2", FINAL_MESSAGE)), + chunk(ev_completed("resp-2")), + ]; + let (server, _completions) = + start_streaming_sse_server(vec![response_completed_chunks("resp-1"), final_chunks]).await; + let codex = build_codex(&server).await; + + let runtime_handle = tokio::runtime::Handle::current(); + let codex_for_steer = Arc::clone(&codex); + let steer_thread = std::thread::spawn(move || { + run_closed_rx + .recv_timeout(Duration::from_secs(15)) + .expect("regular task should reach its post-run finalization gap"); + runtime_handle.block_on(async { + codex_for_steer + .steer_input( + vec![UserInput::Text { + text: STEER_PROMPT.to_string(), + text_elements: Vec::new(), + }], + /*additional_context*/ Default::default(), + /*expected_turn_id*/ None, + /*client_user_message_id*/ + Some(CLIENT_ID.to_string()), + /*responsesapi_client_metadata*/ None, + ) + .await + .expect("steer should be accepted before task finalization"); + }); + steer_finished_tx + .send(()) + .expect("task finalization should wait for the steer"); + }); + + armed.store(true, Ordering::SeqCst); + submit_user_input(&codex, INITIAL_PROMPT).await; + + let mut turn_started_count = 0; + let mut late_user_message_client_ids = Vec::new(); + let mut received_final_message = false; + tokio::time::timeout(Duration::from_secs(30), async { + loop { + let event = codex + .next_event() + .await + .expect("event stream should remain open"); + match event.msg { + EventMsg::TurnStarted(_) => turn_started_count += 1, + EventMsg::UserMessage(message) if message.message == STEER_PROMPT => { + late_user_message_client_ids.push(message.client_id); + } + EventMsg::AgentMessage(message) if message.message == FINAL_MESSAGE => { + received_final_message = true; + } + EventMsg::TurnComplete(_) => break, + _ => {} + } + } + }) + .await + .expect("continued regular task should complete"); + steer_thread + .join() + .expect("late-steer worker should not panic"); + + let requests = server.requests().await; + assert_eq!(requests.len(), 2); + let continued_body: Value = from_slice(&requests[1]).expect("parse continued request"); + assert_eq!( + message_input_texts(&continued_body, "user") + .into_iter() + .filter(|text| text == STEER_PROMPT) + .count(), + 1 + ); + assert_eq!(turn_started_count, 1); + assert_eq!( + late_user_message_client_ids, + vec![Some(CLIENT_ID.to_string())] + ); + assert!(received_final_message); + + server.shutdown().await; +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn steer_interrupts_wait_agent_and_is_sent_in_follow_up_request() { const WAIT_CALL_ID: &str = "wait-call";