Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions codex-rs/core/src/session/input_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ impl InputQueue {
}

impl TurnInputQueue {
pub(crate) fn is_empty(&self) -> bool {
self.items.is_empty()
}

fn has_user_input(&self) -> bool {
self.items
.iter()
Expand Down
149 changes: 149 additions & 0 deletions codex-rs/core/src/session/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9213,6 +9213,61 @@ impl SessionTask for CompletingTask {
}
}

struct PendingInputContinuationTask {
final_pending_input_check_reached: Arc<tokio::sync::Notify>,
allow_initial_run_to_finish: Arc<tokio::sync::Notify>,
}

impl SessionTask for PendingInputContinuationTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
}

fn span_name(&self) -> &'static str {
"session_task.pending_input_continuation"
}

async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
_ctx: Arc<TurnContext>,
_input: Vec<TurnInput>,
_cancellation_token: CancellationToken,
) -> SessionTaskResult {
let session = session.clone_session();
assert!(
!session
.input_queue
.has_pending_input(&session.active_turn)
.await
);
self.final_pending_input_check_reached.notify_one();
self.allow_initial_run_to_finish.notified().await;
Ok(None)
}

fn supports_pending_input_continuation(&self) -> bool {
true
}

async fn run_pending_input_continuation(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
cancellation_token: CancellationToken,
) -> SessionTaskResult {
crate::session::turn::run_turn(
session.clone_session(),
ctx,
session.turn_extension_data(),
Vec::new(),
/*prewarmed_client_session*/ None,
cancellation_token,
)
.await
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TerminalEventKind {
TurnComplete,
Expand Down Expand Up @@ -9733,6 +9788,100 @@ async fn task_finish_emits_turn_item_lifecycle_for_leftover_pending_user_input()
));
}

#[test]
fn task_finish_continues_input_accepted_after_final_pending_input_check() {
Comment thread
friel-openai marked this conversation as resolved.
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
// Match the production runtime because this test executes the full sampling future.
.thread_stack_size(16 * 1024 * 1024)
.enable_all()
.build()
.expect("build test runtime");
runtime.block_on(task_finish_continues_late_input());
}

async fn task_finish_continues_late_input() {
let server = start_mock_server().await;
let response_mock = mount_sse_once(
&server,
sse(vec![
ev_response_created("response-1"),
ev_completed("response-1"),
]),
)
.await;
let base_url = server.uri();
let (session, turn_context, rx) = make_session_and_context_with_auth_and_config_and_rx(
CodexAuth::from_api_key("Test API Key"),
Vec::new(),
move |config| config.model_provider.base_url = Some(base_url),
)
.await;
let final_pending_input_check_reached = Arc::new(tokio::sync::Notify::new());
let allow_initial_run_to_finish = Arc::new(tokio::sync::Notify::new());

session
.spawn_task(
Arc::clone(&turn_context),
Vec::new(),
PendingInputContinuationTask {
final_pending_input_check_reached: Arc::clone(&final_pending_input_check_reached),
allow_initial_run_to_finish: Arc::clone(&allow_initial_run_to_finish),
},
)
.await;
timeout(
StdDuration::from_secs(2),
final_pending_input_check_reached.notified(),
)
.await
.expect("task should reach its final pending-input check");

let client_id = "late-steer-client-id";
session
.steer_input(
vec![UserInput::Text {
text: "late steer".to_string(),
text_elements: Vec::new(),
}],
/*additional_context*/ Default::default(),
Some(&turn_context.sub_id),
Some(client_id.to_string()),
/*responsesapi_client_metadata*/ None,
)
.await
.expect("steer should be accepted while the task is still active");
allow_initial_run_to_finish.notify_one();

let mut user_message_client_ids = Vec::new();
timeout(StdDuration::from_secs(15), async {
loop {
let event = rx.recv().await.expect("event channel should remain open");
match event.msg {
EventMsg::UserMessage(event) => {
user_message_client_ids.push(event.client_id);
}
EventMsg::TurnComplete(_) => break,
_ => {}
}
}
})
.await
.expect("continued task should complete");

let request = response_mock.single_request();
assert_eq!(
request
.message_input_texts("user")
.into_iter()
.filter(|text| text == "late steer")
.count(),
1
);
assert_eq!(user_message_client_ids, vec![Some(client_id.to_string())]);
assert!(session.active_turn.lock().await.is_none());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn task_finish_emits_thread_idle_lifecycle_after_active_turn_clears() {
struct ThreadIdleRecorder {
Expand Down
141 changes: 112 additions & 29 deletions codex-rs/core/src/tasks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ const TASK_COMPACT_METRIC: &str = "codex.task.compact";

pub(crate) type SessionTaskResult = CodexResult<Option<String>>;

/// Whether task finalization completed or reserved the active task for pending input.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum TaskFinishAction {
Finish,
Continue,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum InterruptedTurnHistoryMarker {
Disabled,
Expand Down Expand Up @@ -237,6 +244,24 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
cancellation_token: CancellationToken,
) -> impl std::future::Future<Output = SessionTaskResult> + Send;

/// Whether this task can consume input accepted while its initial run was finishing.
fn supports_pending_input_continuation(&self) -> bool {
false
}

/// Continues the task without repeating its turn-start lifecycle.
fn run_pending_input_continuation(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
cancellation_token: CancellationToken,
) -> impl std::future::Future<Output = SessionTaskResult> + Send {
async move {
let _ = (self, session, ctx, cancellation_token);
unreachable!("task does not support pending input continuation")
}
}

/// Gives the task a chance to perform cleanup after an abort.
///
/// The default implementation is a no-op; override this if additional
Expand Down Expand Up @@ -266,6 +291,15 @@ pub(crate) trait AnySessionTask: Send + Sync + 'static {
cancellation_token: CancellationToken,
) -> BoxFuture<'static, SessionTaskResult>;

fn supports_pending_input_continuation(&self) -> bool;

fn run_pending_input_continuation(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
cancellation_token: CancellationToken,
) -> BoxFuture<'static, SessionTaskResult>;

fn abort<'a>(
&'a self,
session: Arc<SessionTaskContext>,
Expand Down Expand Up @@ -301,6 +335,24 @@ where
))
}

fn supports_pending_input_continuation(&self) -> bool {
<T as SessionTask>::supports_pending_input_continuation(self)
}

fn run_pending_input_continuation(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
cancellation_token: CancellationToken,
) -> BoxFuture<'static, SessionTaskResult> {
Box::pin(<T as SessionTask>::run_pending_input_continuation(
self,
session,
ctx,
cancellation_token,
))
}

fn abort<'a>(
&'a self,
session: Arc<SessionTaskContext>,
Expand Down Expand Up @@ -401,7 +453,7 @@ impl Session {
let handle = tokio::spawn(
async move {
let ctx_for_finish = Arc::clone(&ctx);
let task_result = task_for_run
let mut task_result = Arc::clone(&task_for_run)
.run(
Arc::clone(&session_ctx),
ctx,
Expand All @@ -411,22 +463,39 @@ impl Session {
.instrument(trace_span!("session_task.run"))
.await;
let sess = session_ctx.clone_session();
if let Err(err) = sess.flush_rollout().await {
warn!("failed to flush rollout before completing turn: {err}");
sess.send_event(
ctx_for_finish.as_ref(),
EventMsg::Warning(WarningEvent {
message: format!(
"Failed to save the conversation transcript; Codex will continue retrying. Error: {err}"
),
}),
)
.await;
}
if !task_cancellation_token.is_cancelled() {
// Finish uniformly from the spawn site so all tasks share the same lifecycle.
sess.on_task_finished(Arc::clone(&ctx_for_finish), task_result)
loop {
if let Err(err) = sess.flush_rollout().await {
warn!("failed to flush rollout before completing turn: {err}");
sess.send_event(
ctx_for_finish.as_ref(),
EventMsg::Warning(WarningEvent {
message: format!(
"Failed to save the conversation transcript; Codex will continue retrying. Error: {err}"
),
}),
)
.await;
}
if task_cancellation_token.is_cancelled() {
break;
}
// Finish uniformly from the spawn site so all tasks share the same lifecycle.
match sess
.on_task_finished(Arc::clone(&ctx_for_finish), task_result)
.await
{
TaskFinishAction::Finish => break,
TaskFinishAction::Continue => {
task_result = Arc::clone(&task_for_run)
.run_pending_input_continuation(
Arc::clone(&session_ctx),
Arc::clone(&ctx_for_finish),
task_cancellation_token.child_token(),
)
.instrument(trace_span!("session_task.run_pending_input_continuation"))
.await;
}
}
}
done_clone.notify_waiters();
}
Expand Down Expand Up @@ -560,11 +629,16 @@ impl Session {
true
}

#[expect(
clippy::await_holding_invalid_type,
reason = "task removal and the final pending-input check must remain atomic with steering"
)]
pub async fn on_task_finished(
self: &Arc<Self>,
turn_context: Arc<TurnContext>,
task_result: SessionTaskResult,
) {
) -> TaskFinishAction {
let can_continue = task_result.is_ok();
let (last_agent_message, abort_reason) = match task_result {
Ok(last_agent_message) => (last_agent_message, None),
Err(CodexErr::TurnAborted) => (None, Some(TurnAbortReason::Interrupted)),
Expand All @@ -573,21 +647,29 @@ impl Session {
(None, None)
}
};
turn_context
.turn_metadata_state
.cancel_git_enrichment_task();

let turn_state = {
let mut active = self.active_turn.lock().await;
active.as_mut().and_then(|active_turn| {
let task = active_turn.task.take()?;
task.handle.detach();
Some(Arc::clone(&active_turn.turn_state))
})
};
let Some(turn_state) = turn_state else {
return;
let Some(active_turn) = active.as_mut() else {
return TaskFinishAction::Finish;
};
let Some(task) = active_turn.task.as_ref() else {
return TaskFinishAction::Finish;
};
if can_continue && task.task.supports_pending_input_continuation() {
let turn_state = active_turn.turn_state.lock().await;
if !turn_state.pending_input.is_empty() {
return TaskFinishAction::Continue;
}
}
let Some(task) = active_turn.task.take() else {
return TaskFinishAction::Finish;
};
task.handle.detach();
Arc::clone(&active_turn.turn_state)
};
turn_context
.turn_metadata_state
.cancel_git_enrichment_task();
let pending_input = self
.input_queue
.take_pending_input_for_turn_state(turn_state.as_ref())
Expand Down Expand Up @@ -801,6 +883,7 @@ impl Session {
if let Err(err) = self.flush_rollout().await {
warn!("failed to flush rollout after emitting terminal turn event: {err}");
}
TaskFinishAction::Finish
}

async fn take_active_turn(&self) -> Option<ActiveTurn> {
Expand Down
Loading
Loading