From 5eaa4d45069bd5af8ccfa34c3851303a0103618f Mon Sep 17 00:00:00 2001 From: Steven Lee Date: Fri, 26 Jun 2026 20:35:46 +0000 Subject: [PATCH 1/4] Route MCP OAuth recovery through Codex --- .../rmcp-client/src/http_client_adapter.rs | 40 ++- codex-rs/rmcp-client/src/lib.rs | 1 + codex-rs/rmcp-client/src/oauth.rs | 89 +---- .../src/oauth/refresh_transaction.rs | 171 +++++++--- .../src/oauth/tests/persistor_tests.rs | 63 +++- codex-rs/rmcp-client/src/oauth_transport.rs | 221 +++++++++++++ .../rmcp-client/src/oauth_transport_tests.rs | 189 +++++++++++ codex-rs/rmcp-client/src/rmcp_client.rs | 235 ++++++++----- .../rmcp-client/src/streamable_http_retry.rs | 119 ++++++- .../tests/streamable_http_oauth_internal.rs | 219 ++++++++++++ .../tests/streamable_http_oauth_startup.rs | 313 +++++++++++++++++- 11 files changed, 1435 insertions(+), 225 deletions(-) create mode 100644 codex-rs/rmcp-client/src/oauth_transport.rs create mode 100644 codex-rs/rmcp-client/src/oauth_transport_tests.rs create mode 100644 codex-rs/rmcp-client/tests/streamable_http_oauth_internal.rs diff --git a/codex-rs/rmcp-client/src/http_client_adapter.rs b/codex-rs/rmcp-client/src/http_client_adapter.rs index 19befb62355e..a7d934134812 100644 --- a/codex-rs/rmcp-client/src/http_client_adapter.rs +++ b/codex-rs/rmcp-client/src/http_client_adapter.rs @@ -22,6 +22,7 @@ use codex_exec_server::HttpResponseBodyStream; use futures::StreamExt; use futures::stream; use futures::stream::BoxStream; +use oauth2::AccessToken; use reqwest::StatusCode; use reqwest::header::ACCEPT; use reqwest::header::AUTHORIZATION; @@ -61,6 +62,10 @@ pub(crate) struct StreamableHttpClientAdapter { pub(crate) enum StreamableHttpClientAdapterError { #[error("streamable HTTP session expired with 404 Not Found")] SessionExpired404, + #[error("MCP server rejected the access token with HTTP 401 Unauthorized")] + AccessTokenRejected { rejected_access_token: AccessToken }, + #[error("MCP OAuth operation failed: {0:#}")] + OAuth(#[source] anyhow::Error), #[error(transparent)] HttpRequest(#[from] ExecServerError), #[error("invalid HTTP header: {0}")] @@ -109,7 +114,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { JSON_MIME_TYPE.to_string(), StreamableHttpClientAdapterError::Header, )?; - if let Some(auth_token) = auth_token { + if let Some(auth_token) = auth_token.as_deref() { insert_header( &mut headers, AUTHORIZATION, @@ -162,6 +167,11 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { StreamableHttpClientAdapterError::SessionExpired404, )); } + if response.status == StatusCode::UNAUTHORIZED.as_u16() + && let Some(error) = access_token_rejected(auth_token.as_deref()) + { + return Err(error); + } if response.status == StatusCode::UNAUTHORIZED.as_u16() && let Some(header) = response_header(&response.headers, reqwest::header::WWW_AUTHENTICATE) @@ -240,7 +250,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { let mut headers = self.default_headers.clone(); headers.extend(custom_headers); self.add_auth_headers(&mut headers); - if let Some(auth_token) = auth_token { + if let Some(auth_token) = auth_token.as_deref() { insert_header( &mut headers, AUTHORIZATION, @@ -274,6 +284,11 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { if response.status == StatusCode::METHOD_NOT_ALLOWED.as_u16() { return Ok(()); } + if response.status == StatusCode::UNAUTHORIZED.as_u16() + && let Some(error) = access_token_rejected(auth_token.as_deref()) + { + return Err(error); + } if !status_is_success(response.status) { return Err(StreamableHttpError::UnexpectedServerResponse( format!("DELETE returned HTTP {}", response.status).into(), @@ -316,7 +331,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { StreamableHttpClientAdapterError::Header, )?; } - if let Some(auth_token) = auth_token { + if let Some(auth_token) = auth_token.as_deref() { insert_header( &mut headers, AUTHORIZATION, @@ -349,6 +364,11 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { StreamableHttpClientAdapterError::SessionExpired404, )); } + if response.status == StatusCode::UNAUTHORIZED.as_u16() + && let Some(error) = access_token_rejected(auth_token.as_deref()) + { + return Err(error); + } if !status_is_success(response.status) { return Err(StreamableHttpError::UnexpectedServerResponse( format!("GET returned HTTP {}", response.status).into(), @@ -371,6 +391,20 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { } } +fn access_token_rejected( + auth_token: Option<&str>, +) -> Option> { + // Preserve the token associated with this response. Reading the current credential after a + // delayed 401 is racy: another concurrent request may already have refreshed A to B, in which + // case recovery must retry B rather than refresh B a second time. AccessToken's Debug + // implementation redacts the secret if this error is logged. + auth_token.map(|rejected_access_token| { + StreamableHttpError::Client(StreamableHttpClientAdapterError::AccessTokenRejected { + rejected_access_token: AccessToken::new(rejected_access_token.to_string()), + }) + }) +} + impl StreamableHttpClientAdapter { fn add_auth_headers(&self, headers: &mut HeaderMap) { if let Some(auth_provider) = &self.auth_provider { diff --git a/codex-rs/rmcp-client/src/lib.rs b/codex-rs/rmcp-client/src/lib.rs index 74dd1697bebd..84eba64fed3f 100644 --- a/codex-rs/rmcp-client/src/lib.rs +++ b/codex-rs/rmcp-client/src/lib.rs @@ -6,6 +6,7 @@ mod in_process_transport; mod logging_client_handler; mod oauth; mod oauth_http_client; +mod oauth_transport; mod perform_oauth_login; mod program_resolver; mod rmcp_client; diff --git a/codex-rs/rmcp-client/src/oauth.rs b/codex-rs/rmcp-client/src/oauth.rs index 136813ea7882..bf11a5d5a2bf 100644 --- a/codex-rs/rmcp-client/src/oauth.rs +++ b/codex-rs/rmcp-client/src/oauth.rs @@ -64,6 +64,7 @@ use codex_keyring_store::DefaultKeyringStore; use codex_keyring_store::KeyringStore; use codex_utils_home_dir::find_codex_home; +pub(crate) use self::refresh_transaction::request_oauth_token_response; pub(crate) use self::resolved_store::ResolvedOAuthCredentialStore; pub(crate) use self::resolved_store::ResolvedOAuthTokens; pub(crate) use self::resolved_store::load_oauth_tokens_from_store; @@ -71,7 +72,8 @@ pub(crate) use self::resolved_store::resolve_oauth_tokens; const KEYRING_SERVICE: &str = "Codex MCP Credentials"; const MCP_OAUTH_SECRET_PREFIX: &str = "MCP_OAUTH"; -const REFRESH_SKEW_MILLIS: u64 = 30_000; +// Refresh proactively so ordinary requests do not race token expiry. +const REFRESH_SKEW_MILLIS: u64 = 60_000; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct StoredOAuthTokens { @@ -501,71 +503,6 @@ impl OAuthPersistor { }), } } - - /// Persists a refresh that RMCP performed through its existing internal fallback path. - /// - /// Codex preflight refreshes are serialized by `refresh_transaction`. RMCP keeps its current - /// refresh capability until the next independently mergeable layer installs Codex recovery - /// for every transport path, so this compatibility path must remain in place for now. - #[expect( - clippy::await_holding_invalid_type, - reason = "AuthorizationManager async access must be serialized through its Tokio mutex" - )] - pub(crate) async fn persist_if_needed(&self) -> Result<()> { - let (client_id, maybe_credentials) = { - let manager = self.inner.authorization_manager.clone(); - let guard = manager.lock().await; - guard.get_credentials().await - }?; - - match maybe_credentials { - Some(credentials) => { - let mut current_credentials = self.inner.current_credentials.lock().await; - let new_token_response = WrappedOAuthTokenResponse(credentials.clone()); - let same_token = current_credentials - .as_ref() - .map(|previous| previous.token_response == new_token_response) - .unwrap_or(false); - let expires_at = if same_token { - current_credentials - .as_ref() - .and_then(|previous| previous.expires_at) - } else { - compute_expires_at_millis(&credentials) - }; - let stored = StoredOAuthTokens { - server_name: self.inner.server_name.clone(), - url: self.inner.url.clone(), - client_id, - token_response: new_token_response, - expires_at, - }; - if current_credentials.as_ref() != Some(&stored) { - save_to_resolved_store(&DefaultKeyringStore, &self.inner, &stored)?; - *current_credentials = Some(stored); - } - } - None => { - let mut current_credentials = self.inner.current_credentials.lock().await; - if current_credentials.take().is_some() - && let Err(error) = delete_from_resolved_store( - &DefaultKeyringStore, - &self.inner.server_name, - &self.inner.url, - self.inner.credential_store, - ) - { - warn!( - server_name = %self.inner.server_name, - error = %error, - "failed to remove MCP OAuth credentials from the resolved store" - ); - } - } - } - - Ok(()) - } } fn save_to_resolved_store( @@ -586,26 +523,6 @@ fn save_to_resolved_store( } } -fn delete_from_resolved_store( - keyring_store: &K, - server_name: &str, - url: &str, - credential_store: ResolvedOAuthCredentialStore, -) -> Result { - match credential_store { - ResolvedOAuthCredentialStore::File => { - let key = compute_store_key(server_name, url)?; - delete_oauth_tokens_from_file(&key) - } - ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Direct) => { - delete_oauth_tokens_from_direct_keyring(keyring_store, server_name, url) - } - ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Secrets) => { - delete_oauth_tokens_from_secrets_keyring(keyring_store, server_name, url) - } - } -} - const FALLBACK_FILENAME: &str = ".credentials.json"; const MCP_SERVER_TYPE: &str = "http"; diff --git a/codex-rs/rmcp-client/src/oauth/refresh_transaction.rs b/codex-rs/rmcp-client/src/oauth/refresh_transaction.rs index a665160f6184..61283e132451 100644 --- a/codex-rs/rmcp-client/src/oauth/refresh_transaction.rs +++ b/codex-rs/rmcp-client/src/oauth/refresh_transaction.rs @@ -10,6 +10,7 @@ use anyhow::Context; use anyhow::Result; use codex_keyring_store::DefaultKeyringStore; use codex_keyring_store::KeyringStore; +use oauth2::AccessToken; use oauth2::TokenResponse; use rmcp::transport::auth::AuthError; use rmcp::transport::auth::AuthorizationManager; @@ -36,34 +37,45 @@ const REFRESH_REQUEST_TIMEOUT: Duration = Duration::from_secs(45); impl OAuthPersistor { pub(crate) async fn refresh_if_needed(&self) -> Result<()> { - self.refresh_if_needed_in(&DefaultKeyringStore, REFRESH_REQUEST_TIMEOUT) - .await + self.refresh_in( + DefaultKeyringStore, + RefreshReason::Expiry, + REFRESH_REQUEST_TIMEOUT, + ) + .await } - /// Injects the credential backend and provider timeout for deterministic failure-path tests. - pub(super) async fn refresh_if_needed_in( + pub(crate) async fn refresh_after_unauthorized( &self, - keyring_store: &K, - refresh_request_timeout: Duration, + rejected_access_token: AccessToken, ) -> Result<()> { - let expires_at = { - let guard = self.inner.current_credentials.lock().await; - guard.as_ref().and_then(|tokens| tokens.expires_at) - }; - - if !token_needs_refresh(expires_at) { - return Ok(()); - } - - self.run_owned_refresh_transaction(keyring_store.clone(), refresh_request_timeout) - .await + self.refresh_in( + DefaultKeyringStore, + RefreshReason::Unauthorized { + rejected_access_token, + }, + REFRESH_REQUEST_TIMEOUT, + ) + .await } - async fn run_owned_refresh_transaction( + /// Injects the credential backend and provider timeout for deterministic failure-path tests. + pub(super) async fn refresh_in( &self, keyring_store: K, + reason: RefreshReason, refresh_request_timeout: Duration, ) -> Result<()> { + if matches!(reason, RefreshReason::Expiry) { + let expires_at = { + let guard = self.inner.current_credentials.lock().await; + guard.as_ref().and_then(|tokens| tokens.expires_at) + }; + if !token_needs_refresh(expires_at) { + return Ok(()); + } + } + let persistor = self.clone(); let server_name = self.inner.server_name.clone(); // Once a provider request can consume a rotating refresh token, dropping the caller's @@ -75,16 +87,17 @@ impl OAuthPersistor { // permits a later serialized retry. Some providers accept the previous token during a // grace period; otherwise that retry surfaces reauthorization. We accept that residual // token-family-revocation risk rather than holding the lock indefinitely. + let refresh_reason = reason.as_str(); tokio::spawn(async move { let result = persistor - .refresh_transaction(&keyring_store, refresh_request_timeout) + .refresh_transaction(&keyring_store, reason, refresh_request_timeout) .await; // Keep this summary inside the owned task so caller cancellation cannot suppress it. if let Err(error) = &result { warn!( server_name = %persistor.inner.server_name, - refresh_reason = "expiry", + refresh_reason, error = %error, "MCP OAuth refresh transaction failed" ); @@ -105,13 +118,14 @@ impl OAuthPersistor { skip_all, fields( server_name = %self.inner.server_name, - refresh_reason = "expiry", + refresh_reason = reason.as_str(), ), err )] async fn refresh_transaction( &self, keyring_store: &K, + reason: RefreshReason, refresh_request_timeout: Duration, ) -> Result<()> { let transaction_started_at = Instant::now(); @@ -149,7 +163,19 @@ impl OAuthPersistor { ); }; - if !token_needs_refresh(latest.expires_at) { + let latest_access_token = latest.token_response.0.access_token().secret(); + // Expiry refresh can adopt any reread that is now healthy. A 401 belongs to the access + // token sent with that specific request, not to this client's mutable current snapshot. + // If a delayed request rejected A after another request refreshed A to B, adopt B and let + // the caller retry instead of rotating B again. + let should_adopt = !token_needs_refresh(latest.expires_at) + && match &reason { + RefreshReason::Expiry => true, + RefreshReason::Unauthorized { + rejected_access_token, + } => rejected_access_token.secret() != latest_access_token, + }; + if should_adopt { debug!("adopting newer MCP OAuth credentials without contacting the provider"); self.adopt_credentials(latest).await?; return Ok(()); @@ -180,9 +206,14 @@ impl OAuthPersistor { // the guard prevents ordinary requests from observing credentials while they are staged // and committed. let mut guard = manager.lock().await; - install_tokens_in_manager_guard(&mut guard, &latest) - .await - .context("failed to stage OAuth credentials for refresh")?; + if let Err(error) = + install_tokens_in_manager_guard(&mut guard, &latest, CredentialExposure::Refresh).await + { + install_tokens_in_manager_guard(&mut guard, &latest, CredentialExposure::Request) + .await + .context("failed to restore request-only OAuth credentials")?; + return Err(error).context("failed to stage OAuth credentials for refresh"); + } // The provider request has its own bound. The independently owned task prevents caller // startup and operation deadlines from canceling this future after the provider may have // rotated the refresh token. @@ -191,13 +222,13 @@ impl OAuthPersistor { timeout_ms = refresh_request_timeout.as_millis(), "requesting refreshed MCP OAuth credentials from the provider" ); - let refreshed = match timeout(refresh_request_timeout, guard.refresh_token()).await { + let refresh_result = match timeout(refresh_request_timeout, guard.refresh_token()).await { Ok(Ok(token_response)) => { debug!( provider_elapsed_ms = provider_started_at.elapsed().as_millis(), "received refreshed MCP OAuth credentials from the provider" ); - refreshed_tokens(token_response, &latest, &self.inner) + Ok(refreshed_tokens(token_response, &latest, &self.inner)) } Ok(Err(error)) => { warn!( @@ -205,12 +236,12 @@ impl OAuthPersistor { error = %error, "MCP OAuth provider refresh failed" ); - return Err(error).with_context(|| { + Err(error).with_context(|| { format!( "failed to refresh OAuth tokens for server {}", self.inner.server_name ) - }); + }) } Err(_) => { warn!( @@ -218,12 +249,17 @@ impl OAuthPersistor { timeout_ms = refresh_request_timeout.as_millis(), "MCP OAuth provider refresh timed out; the outcome is unknown and a later serialized retry is permitted" ); - anyhow::bail!( + Err(anyhow::anyhow!( "timed out after {refresh_request_timeout:?} refreshing OAuth tokens for server {}", self.inner.server_name - ); + )) } }; + let request_tokens = refresh_result.as_ref().unwrap_or(&latest); + install_tokens_in_manager_guard(&mut guard, request_tokens, CredentialExposure::Request) + .await + .context("failed to restore request-only OAuth credentials")?; + let refreshed = refresh_result?; // Once the provider rotates a refresh token, the owned task must attempt persistence even // if the caller's deadline expires in the meantime. Refresh persistence stays on the @@ -246,21 +282,18 @@ impl OAuthPersistor { error = %error, "failed to persist refreshed MCP OAuth credentials; returning the error and restoring the previous in-process credentials" ); - install_tokens_in_manager_guard(&mut guard, &latest) + install_tokens_in_manager_guard(&mut guard, &latest, CredentialExposure::Request) .await .context( - "failed to restore previous OAuth credentials after refresh persistence failed", + "failed to restore previous request-only OAuth credentials after refresh persistence failed", )?; return Err(error); } - // This independently mergeable layer intentionally retains RMCP's legacy - // `persist_if_needed` compatibility path until Codex owns every transport refresh in the - // next layer. `guard.refresh_token()` installed the provider's raw response, which may - // omit a refresh token or scopes that `refreshed_tokens` carried forward. Commit the same - // merged object to RMCP before releasing the guard so the compatibility path cannot write - // the raw partial response back over the authoritative durable credential. - install_tokens_in_manager_guard(&mut guard, &refreshed) + // `guard.refresh_token()` installed the provider's raw response, which may omit fields + // that `refreshed_tokens` carried forward. Commit the merged object before releasing the + // guard, then expose only its request-safe form so RMCP cannot refresh independently. + install_tokens_in_manager_guard(&mut guard, &refreshed, CredentialExposure::Request) .await .context( "refreshed OAuth tokens were persisted but could not be installed in the authorization manager", @@ -288,6 +321,20 @@ impl OAuthPersistor { } } +pub(super) enum RefreshReason { + Expiry, + Unauthorized { rejected_access_token: AccessToken }, +} + +impl RefreshReason { + fn as_str(&self) -> &'static str { + match self { + Self::Expiry => "expiry", + Self::Unauthorized { .. } => "unauthorized", + } + } +} + #[expect( clippy::await_holding_invalid_type, reason = "AuthorizationManager async access must be serialized through its Tokio mutex" @@ -298,16 +345,17 @@ async fn install_tokens_in_manager( ) -> Result<()> { let manager = authorization_manager.clone(); let mut guard = manager.lock().await; - install_tokens_in_manager_guard(&mut guard, tokens).await + install_tokens_in_manager_guard(&mut guard, tokens, CredentialExposure::Request).await } async fn install_tokens_in_manager_guard( authorization_manager: &mut AuthorizationManager, tokens: &StoredOAuthTokens, + exposure: CredentialExposure, ) -> Result<()> { let store = InMemoryCredentialStore::new(); store - .save(stored_credentials_from_tokens(tokens)) + .save(stored_credentials_from_tokens(tokens, exposure)) .await .context("failed to stage OAuth tokens for authorization manager")?; @@ -323,16 +371,36 @@ async fn install_tokens_in_manager_guard( Ok(()) } -fn stored_credentials_from_tokens(tokens: &StoredOAuthTokens) -> StoredCredentials { - let token_response = tokens.token_response.0.clone(); +/// Controls which credentials RMCP may observe. +/// +/// Ordinary transport requests get neither the refresh token nor expiry metadata, so RMCP cannot +/// refresh outside Codex's cross-process transaction. Full credentials exist in the manager only +/// while that transaction owns both the credential lock and the manager mutex. +#[derive(Clone, Copy)] +enum CredentialExposure { + Request, + Refresh, +} + +fn stored_credentials_from_tokens( + tokens: &StoredOAuthTokens, + exposure: CredentialExposure, +) -> StoredCredentials { + let token_response = match exposure { + CredentialExposure::Request => request_oauth_token_response(tokens), + CredentialExposure::Refresh => tokens.token_response.0.clone(), + }; let granted_scopes = token_response .scopes() .map(|scopes| scopes.iter().map(|scope| scope.to_string()).collect()) .unwrap_or_default(); - let token_received_at = SystemTime::now() - .duration_since(UNIX_EPOCH) - .ok() - .map(|duration| duration.as_secs()); + let token_received_at = match exposure { + CredentialExposure::Request => None, + CredentialExposure::Refresh => SystemTime::now() + .duration_since(UNIX_EPOCH) + .ok() + .map(|duration| duration.as_secs()), + }; StoredCredentials::new( tokens.client_id.clone(), @@ -342,6 +410,13 @@ fn stored_credentials_from_tokens(tokens: &StoredOAuthTokens) -> StoredCredentia ) } +pub(crate) fn request_oauth_token_response(tokens: &StoredOAuthTokens) -> OAuthTokenResponse { + let mut token_response = tokens.token_response.0.clone(); + token_response.set_refresh_token(None); + token_response.set_expires_in(None); + token_response +} + fn refreshed_tokens( mut token_response: OAuthTokenResponse, previous: &StoredOAuthTokens, diff --git a/codex-rs/rmcp-client/src/oauth/tests/persistor_tests.rs b/codex-rs/rmcp-client/src/oauth/tests/persistor_tests.rs index 72b213e119f7..9da31fcb412e 100644 --- a/codex-rs/rmcp-client/src/oauth/tests/persistor_tests.rs +++ b/codex-rs/rmcp-client/src/oauth/tests/persistor_tests.rs @@ -31,6 +31,7 @@ use crate::oauth::fallback_file_path; use crate::oauth::load_oauth_tokens_from_file; use crate::oauth::load_oauth_tokens_from_keyring; use crate::oauth::refresh_lock::RefreshCredentialLock; +use crate::oauth::refresh_transaction::RefreshReason; use crate::oauth::save_oauth_tokens_to_file; use crate::oauth::save_oauth_tokens_with_keyring; @@ -59,10 +60,6 @@ async fn concurrent_refreshes_call_provider_once_and_carry_omitted_fields() -> R second_task.await??; server.verify().await; - // Layer 2 still invokes the legacy RMCP persistence hook after operations. Exercise that hook - // so a raw provider response that omitted refresh token/scopes cannot overwrite the merged - // authoritative credential. - first.persist_if_needed().await?; let stored = load_oauth_tokens_from_file(&initial.server_name, &initial.url)? .expect("refreshed credentials should be stored"); let mut expected_response = initial.token_response.0.clone(); @@ -78,6 +75,53 @@ async fn concurrent_refreshes_call_provider_once_and_carry_omitted_fields() -> R Ok(()) } +#[tokio::test(flavor = "current_thread")] +async fn delayed_unauthorized_retries_adopt_the_winning_token() -> Result<()> { + let _env = TempCodexHome::new(); + let server = MockServer::start().await; + mount_oauth_metadata(&server).await; + let _refresh_started = mount_delayed_refresh(&server, "refreshed-access-token").await; + let mut initial = expired_tokens(&format!("{}/mcp", server.uri())); + initial.expires_at = None; + initial.token_response.0.set_expires_in(None); + save_oauth_tokens_to_file(&initial)?; + + let first = persistor_for(&initial).await?; + let second_manager = authorization_manager_for(&initial).await?; + let second = OAuthPersistor::new( + initial.server_name.clone(), + initial.url.clone(), + Arc::clone(&second_manager), + ResolvedOAuthCredentialStore::File, + Some(initial.clone()), + ); + let rejected_access_token = initial.token_response.0.access_token().clone(); + + first + .refresh_after_unauthorized(rejected_access_token.clone()) + .await?; + // Both calls model requests that left with A. Once the first 401 rotates A to B, later 401s + // must adopt B and retry their requests instead of rotating B again. + first + .refresh_after_unauthorized(rejected_access_token.clone()) + .await?; + second + .refresh_after_unauthorized(rejected_access_token) + .await?; + + server.verify().await; + let stored = load_oauth_tokens_from_file(&initial.server_name, &initial.url)? + .expect("the winning refresh should be persisted"); + assert_eq!( + stored.token_response.0.access_token().secret(), + "refreshed-access-token" + ); + let adopted = tokens_from_manager(&second_manager).await?; + assert_eq!(adopted.0.access_token().secret(), "refreshed-access-token"); + assert!(adopted.0.refresh_token().is_none()); + Ok(()) +} + #[tokio::test(flavor = "current_thread")] async fn resolved_keyring_read_error_preserves_in_memory_credentials() -> Result<()> { let _env = TempCodexHome::new(); @@ -97,7 +141,11 @@ async fn resolved_keyring_read_error_preserves_in_memory_credentials() -> Result ); let error = persistor - .refresh_if_needed_in(&keyring_store, Duration::from_secs(/*secs*/ 45)) + .refresh_in( + keyring_store, + RefreshReason::Expiry, + Duration::from_secs(/*secs*/ 45), + ) .await .expect_err("the resolved keyring read error should abort refresh"); assert!( @@ -162,8 +210,9 @@ async fn provider_timeout_releases_lock_and_preserves_durable_credentials() -> R let persistor = persistor_for(&initial).await?; let error = persistor - .refresh_if_needed_in( - &MockKeyringStore::default(), + .refresh_in( + MockKeyringStore::default(), + RefreshReason::Expiry, Duration::from_millis(/*millis*/ 50), ) .await diff --git a/codex-rs/rmcp-client/src/oauth_transport.rs b/codex-rs/rmcp-client/src/oauth_transport.rs new file mode 100644 index 000000000000..b6e9556de888 --- /dev/null +++ b/codex-rs/rmcp-client/src/oauth_transport.rs @@ -0,0 +1,221 @@ +//! Codex-owned OAuth policy for RMCP Streamable HTTP traffic. +//! +//! RMCP remains responsible for transport mechanics and bearer-token injection. Codex owns the +//! credential lifecycle: every POST, SSE GET/reconnect, and session DELETE receives proactive +//! refresh from its owning Codex layer, and each path has at most one 401 recovery. The +//! authorization manager only receives request-safe credentials, so it cannot independently +//! refresh outside Codex's serialized transaction. +//! +//! POST recovery is split at an intentional ownership boundary. Client-originated requests and +//! notifications retain their outer `RmcpClient` recovery, which knows the startup/tool deadline +//! and can avoid replaying a request after its caller timed out. RMCP-owned responses to +//! server-initiated requests have no such outer operation, so they recover here. GET/reconnect and +//! DELETE are always RMCP-owned and also recover here. + +use std::collections::HashMap; +use std::sync::Arc; + +use reqwest::header::HeaderName; +use reqwest::header::HeaderValue; +use rmcp::model::ClientJsonRpcMessage; +use rmcp::model::JsonRpcMessage; +use rmcp::transport::auth::AuthClient; +use rmcp::transport::streamable_http_client::StreamableHttpClient; +use rmcp::transport::streamable_http_client::StreamableHttpError; +use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; +use tracing::debug; + +use crate::http_client_adapter::StreamableHttpClientAdapter; +use crate::http_client_adapter::StreamableHttpClientAdapterError; +use crate::oauth::OAuthPersistor; + +type TransportResult = + std::result::Result>; + +#[derive(Clone)] +pub(crate) struct OAuthTransportClient { + auth_client: AuthClient, + persistor: OAuthPersistor, +} + +impl OAuthTransportClient { + pub(crate) fn new( + auth_client: AuthClient, + persistor: OAuthPersistor, + ) -> Self { + Self { + auth_client, + persistor, + } + } + + pub(crate) fn persistor(&self) -> OAuthPersistor { + self.persistor.clone() + } + + async fn preflight(&self, operation: &'static str) -> TransportResult<()> { + debug!( + operation, + "checking MCP OAuth credentials before transport request" + ); + self.persistor + .refresh_if_needed() + .await + .map_err(oauth_transport_error) + } + + async fn recover_after_unauthorized( + &self, + operation: &'static str, + rejected_access_token: Option, + ) -> TransportResult { + let Some(rejected_access_token) = rejected_access_token else { + return Ok(false); + }; + + debug!( + operation, + "recovering once after MCP transport rejected an OAuth access token" + ); + self.persistor + .refresh_after_unauthorized(rejected_access_token) + .await + .map_err(oauth_transport_error)?; + Ok(true) + } +} + +impl StreamableHttpClient for OAuthTransportClient { + type Error = StreamableHttpClientAdapterError; + + async fn post_message( + &self, + uri: Arc, + message: ClientJsonRpcMessage, + session_id: Option>, + auth_token: Option, + custom_headers: HashMap, + ) -> TransportResult { + let is_rmcp_owned_response = matches!( + message, + JsonRpcMessage::Response(_) | JsonRpcMessage::Error(_) + ); + if is_rmcp_owned_response { + self.preflight("post_message").await?; + } + let result = self + .auth_client + .post_message( + Arc::clone(&uri), + message.clone(), + session_id.clone(), + auth_token.clone(), + custom_headers.clone(), + ) + .await; + + // RMCP queues client-originated requests independently of the caller waiting on them. If + // recovery happened here, a timed-out public tool call could still be replayed after its + // refresh finished. The outer RmcpClient path owns those deadlines. Responses to + // server-initiated requests have no outer operation and therefore recover here. + if !is_rmcp_owned_response { + return result; + } + let rejected_access_token = result.as_ref().err().and_then(rejected_access_token); + if self + .recover_after_unauthorized("post_message", rejected_access_token) + .await? + { + self.auth_client + .post_message(uri, message, session_id, auth_token, custom_headers) + .await + } else { + result + } + } + + async fn delete_session( + &self, + uri: Arc, + session_id: Arc, + auth_token: Option, + custom_headers: HashMap, + ) -> TransportResult<()> { + self.preflight("delete_session").await?; + let result = self + .auth_client + .delete_session( + Arc::clone(&uri), + Arc::clone(&session_id), + auth_token.clone(), + custom_headers.clone(), + ) + .await; + let rejected_access_token = result.as_ref().err().and_then(rejected_access_token); + if self + .recover_after_unauthorized("delete_session", rejected_access_token) + .await? + { + self.auth_client + .delete_session(uri, session_id, auth_token, custom_headers) + .await + } else { + result + } + } + + async fn get_stream( + &self, + uri: Arc, + session_id: Arc, + last_event_id: Option, + auth_token: Option, + custom_headers: HashMap, + ) -> TransportResult< + futures::stream::BoxStream<'static, Result>, + > { + self.preflight("get_stream").await?; + let result = self + .auth_client + .get_stream( + Arc::clone(&uri), + Arc::clone(&session_id), + last_event_id.clone(), + auth_token.clone(), + custom_headers.clone(), + ) + .await; + let rejected_access_token = result.as_ref().err().and_then(rejected_access_token); + if self + .recover_after_unauthorized("get_stream", rejected_access_token) + .await? + { + self.auth_client + .get_stream(uri, session_id, last_event_id, auth_token, custom_headers) + .await + } else { + result + } + } +} + +fn rejected_access_token( + error: &StreamableHttpError, +) -> Option { + match error { + StreamableHttpError::Client(StreamableHttpClientAdapterError::AccessTokenRejected { + rejected_access_token, + }) => Some(rejected_access_token.clone()), + _ => None, + } +} + +fn oauth_transport_error( + error: anyhow::Error, +) -> StreamableHttpError { + StreamableHttpError::Client(StreamableHttpClientAdapterError::OAuth(error)) +} + +#[cfg(test)] +#[path = "oauth_transport_tests.rs"] +mod tests; diff --git a/codex-rs/rmcp-client/src/oauth_transport_tests.rs b/codex-rs/rmcp-client/src/oauth_transport_tests.rs new file mode 100644 index 000000000000..31b9e759b70d --- /dev/null +++ b/codex-rs/rmcp-client/src/oauth_transport_tests.rs @@ -0,0 +1,189 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use codex_config::types::AuthKeyringBackendKind; +use codex_config::types::OAuthCredentialsStoreMode; +use codex_exec_server::Environment; +use oauth2::AccessToken; +use oauth2::RefreshToken; +use oauth2::basic::BasicTokenType; +use reqwest::header::HeaderMap; +use rmcp::model::ClientJsonRpcMessage; +use rmcp::transport::auth::AuthClient; +use rmcp::transport::auth::OAuthState; +use rmcp::transport::auth::OAuthTokenResponse; +use rmcp::transport::auth::VendorExtraTokenFields; +use rmcp::transport::streamable_http_client::StreamableHttpClient; +use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; +use serde_json::json; +use tempfile::TempDir; +use tokio::process::Command; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::ResponseTemplate; +use wiremock::matchers::body_string_contains; +use wiremock::matchers::header; +use wiremock::matchers::method; +use wiremock::matchers::path; + +use super::OAuthTransportClient; +use crate::http_client_adapter::StreamableHttpClientAdapter; +use crate::oauth::OAuthPersistor; +use crate::oauth::ResolvedOAuthCredentialStore; +use crate::oauth::StoredOAuthTokens; +use crate::oauth::WrappedOAuthTokenResponse; +use crate::oauth::request_oauth_token_response; +use crate::oauth::save_oauth_tokens; +use crate::oauth_http_client::OAuthHttpClientAdapter; + +const SERVER_NAME: &str = "oauth-transport-response-test"; +const SERVER_URL_ENV: &str = "MCP_TEST_OAUTH_RESPONSE_SERVER_URL"; +const ACCESS_TOKEN_A: &str = "response-access-a"; +const REFRESH_TOKEN_A: &str = "response-refresh-a"; +const ACCESS_TOKEN_B: &str = "response-access-b"; +const REFRESH_TOKEN_B: &str = "response-refresh-b"; + +#[tokio::test] +async fn server_response_post_receives_one_shot_oauth_recovery() -> anyhow::Result<()> { + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/.well-known/oauth-authorization-server/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "authorization_endpoint": format!("{}/oauth/authorize", server.uri()), + "token_endpoint": format!("{}/oauth/token", server.uri()), + "scopes_supported": ["scope-a"], + }))) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/oauth/token")) + .and(body_string_contains("grant_type=refresh_token")) + .and(body_string_contains(format!( + "refresh_token={REFRESH_TOKEN_A}" + ))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "access_token": ACCESS_TOKEN_B, + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": REFRESH_TOKEN_B, + "scope": "scope-a", + }))) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header("authorization", format!("Bearer {ACCESS_TOKEN_A}"))) + .respond_with(ResponseTemplate::new(401)) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header("authorization", format!("Bearer {ACCESS_TOKEN_B}"))) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&server) + .await; + + let codex_home = TempDir::new()?; + let status = Command::new(std::env::current_exe()?) + .args([ + "oauth_transport::tests::server_response_post_child", + "--exact", + "--ignored", + "--nocapture", + ]) + .env("CODEX_HOME", codex_home.path()) + .env(SERVER_URL_ENV, format!("{}/mcp", server.uri())) + .status() + .await?; + anyhow::ensure!(status.success(), "OAuth response child failed: {status}"); + server.verify().await; + Ok(()) +} + +#[tokio::test] +#[ignore = "spawned by server_response_post_receives_one_shot_oauth_recovery"] +async fn server_response_post_child() -> anyhow::Result<()> { + let server_url = std::env::var(SERVER_URL_ENV)?; + let initial_tokens = initial_tokens(&server_url); + save_oauth_tokens( + SERVER_NAME, + &initial_tokens, + OAuthCredentialsStoreMode::File, + AuthKeyringBackendKind::default(), + )?; + + let http_client = Environment::default_for_tests().get_http_client(); + let oauth_http_client = Arc::new(OAuthHttpClientAdapter::new( + Arc::clone(&http_client), + HeaderMap::new(), + )); + let mut oauth_state = + OAuthState::new_with_oauth_http_client(server_url.clone(), oauth_http_client).await?; + oauth_state + .set_credentials( + &initial_tokens.client_id, + request_oauth_token_response(&initial_tokens), + ) + .await?; + let manager = match oauth_state { + OAuthState::Authorized(manager) | OAuthState::Unauthorized(manager) => manager, + _ => anyhow::bail!("unexpected OAuth state during response test setup"), + }; + let auth_client = AuthClient::new( + StreamableHttpClientAdapter::new( + Arc::clone(&http_client), + HeaderMap::new(), + /*auth_provider*/ None, + ), + manager, + ); + let persistor = OAuthPersistor::new( + SERVER_NAME.to_string(), + server_url.clone(), + Arc::clone(&auth_client.auth_manager), + ResolvedOAuthCredentialStore::File, + Some(initial_tokens), + ); + let client = OAuthTransportClient::new(auth_client, persistor); + let response_message: ClientJsonRpcMessage = serde_json::from_value(json!({ + "jsonrpc": "2.0", + "id": "server-request-1", + "result": { + "action": "accept", + "content": { "confirmed": true } + } + }))?; + + let response = client + .post_message( + Arc::from(server_url), + response_message, + Some(Arc::from("response-session")), + /*auth_token*/ None, + HashMap::new(), + ) + .await?; + + assert!(matches!(response, StreamableHttpPostResponse::Accepted)); + Ok(()) +} + +fn initial_tokens(server_url: &str) -> StoredOAuthTokens { + let mut response = OAuthTokenResponse::new( + AccessToken::new(ACCESS_TOKEN_A.to_string()), + BasicTokenType::Bearer, + VendorExtraTokenFields::default(), + ); + response.set_refresh_token(Some(RefreshToken::new(REFRESH_TOKEN_A.to_string()))); + response.set_expires_in(None); + StoredOAuthTokens { + server_name: SERVER_NAME.to_string(), + url: server_url.to_string(), + client_id: "test-client-id".to_string(), + token_response: WrappedOAuthTokenResponse(response), + expires_at: None, + } +} diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 88767310531d..d01f87e3bd72 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -18,6 +18,7 @@ use codex_exec_server::HttpClient; use codex_keyring_store::DefaultKeyringStore; use futures::FutureExt; use futures::future::BoxFuture; +use oauth2::AccessToken; use oauth2::TokenResponse; use reqwest::header::AUTHORIZATION; use reqwest::header::HeaderMap; @@ -71,8 +72,10 @@ use crate::oauth::ResolvedOAuthCredentialStore; use crate::oauth::ResolvedOAuthTokens; use crate::oauth::StoredOAuthTokens; use crate::oauth::load_oauth_tokens_from_store; +use crate::oauth::request_oauth_token_response; use crate::oauth::resolve_oauth_tokens; use crate::oauth_http_client::OAuthHttpClientAdapter; +use crate::oauth_transport::OAuthTransportClient; use crate::stdio_server_launcher::StdioServerCommand; use crate::stdio_server_launcher::StdioServerLauncher; use crate::stdio_server_launcher::StdioServerProcessHandle; @@ -99,7 +102,7 @@ enum PendingTransport { transport: StreamableHttpClientTransport, }, StreamableHttpWithOAuth { - transport: StreamableHttpClientTransport>, + transport: StreamableHttpClientTransport, oauth_persistor: OAuthPersistor, }, } @@ -133,6 +136,7 @@ enum TransportRecipe { store_mode: OAuthCredentialsStoreMode, keyring_backend_kind: AuthKeyringBackendKind, resolved_store: Arc>, + oauth_client: Arc>, http_client: Arc, auth_provider: Option, }, @@ -410,6 +414,7 @@ impl RmcpClient { store_mode, keyring_backend_kind, resolved_store: Arc::new(OnceLock::new()), + oauth_client: Arc::new(OnceLock::new()), http_client, auth_provider, }; @@ -454,7 +459,7 @@ impl RmcpClient { let mut initialize_deadline = timeout.map(|duration| Instant::now() + duration); let (service, oauth_persistor) = self - .connect_pending_transport_with_initialize_retries( + .connect_pending_transport_with_oauth_recovery( pending_transport, client_service.clone(), timeout, @@ -487,12 +492,6 @@ impl RmcpClient { }; } - if let Some(runtime) = oauth_persistor - && let Err(error) = runtime.persist_if_needed().await - { - warn!("failed to persist OAuth tokens after initialize: {error}"); - } - Ok(initialize_result) } @@ -508,7 +507,6 @@ impl RmcpClient { async move { service.list_tools(params).await }.boxed() }) .await?; - self.persist_oauth_tokens().await; Ok(result) } @@ -543,7 +541,6 @@ impl RmcpClient { }) }) .collect::>>()?; - self.persist_oauth_tokens().await; Ok(ListToolsWithConnectorIdResult { next_cursor: result.next_cursor, tools, @@ -570,7 +567,6 @@ impl RmcpClient { async move { service.list_resources(params).await }.boxed() }) .await?; - self.persist_oauth_tokens().await; Ok(result) } @@ -586,7 +582,6 @@ impl RmcpClient { async move { service.list_resource_templates(params).await }.boxed() }) .await?; - self.persist_oauth_tokens().await; Ok(result) } @@ -602,7 +597,6 @@ impl RmcpClient { async move { service.read_resource(params).await }.boxed() }) .await?; - self.persist_oauth_tokens().await; Ok(result) } @@ -660,7 +654,6 @@ impl RmcpClient { .boxed() }) .await?; - self.persist_oauth_tokens().await; Ok(result) } @@ -690,7 +683,6 @@ impl RmcpClient { }, ) .await?; - self.persist_oauth_tokens().await; Ok(()) } @@ -713,14 +705,24 @@ impl RmcpClient { .boxed() }) .await?; - self.persist_oauth_tokens().await; Ok(response) } async fn service(&self) -> Result>> { + self.service_and_oauth_persistor() + .await + .map(|(service, _oauth_persistor)| service) + } + + async fn service_and_oauth_persistor( + &self, + ) -> Result<( + Arc>, + Option, + )> { let guard = self.state.lock().await; match &*guard { - ClientState::Ready { service, .. } => Ok(Arc::clone(service)), + ClientState::Ready { service, oauth } => Ok((Arc::clone(service), oauth.clone())), ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")), ClientState::Closed => Err(anyhow!("MCP client is shut down")), } @@ -753,16 +755,6 @@ impl RmcpClient { drop(previous_state); } - /// Preserve refreshes that RMCP may still perform until Codex becomes the sole owner in the - /// next stack layer. - async fn persist_oauth_tokens(&self) { - if let Some(runtime) = self.oauth_persistor().await - && let Err(error) = runtime.persist_if_needed().await - { - warn!("failed to persist OAuth tokens: {error}"); - } - } - async fn refresh_oauth_if_needed(&self) -> Result<()> { if let Some(runtime) = self.oauth_persistor().await { runtime.refresh_if_needed().await?; @@ -791,6 +783,7 @@ impl RmcpClient { store_mode, keyring_backend_kind, resolved_store, + oauth_client, http_client, auth_provider, } => { @@ -803,6 +796,21 @@ impl RmcpClient { auth_provider.clone() }; + // Reuse one OAuth manager and persistor across initialize retries and session + // reconstruction. This preserves the lifecycle-pinned store and keeps each failed + // request paired with the manager snapshot that supplied its access token. + if let Some(oauth_client) = oauth_client.get() { + let runtime = oauth_client.persistor(); + let transport = StreamableHttpClientTransport::with_client( + oauth_client.clone(), + StreamableHttpClientTransportConfig::with_uri(url.clone()), + ); + return Ok(PendingTransport::StreamableHttpWithOAuth { + transport, + oauth_persistor: runtime, + }); + } + let resolved_oauth_tokens = if bearer_token.is_none() && auth_provider.is_none() && !default_headers.contains_key(AUTHORIZATION) @@ -847,7 +855,7 @@ impl RmcpClient { store: credential_store, }) = resolved_oauth_tokens { - match create_oauth_transport_and_runtime( + match create_oauth_transport_client( server_name, url, initial_tokens.clone(), @@ -857,7 +865,19 @@ impl RmcpClient { ) .await { - Ok((transport, oauth_persistor)) => { + Ok(resolved_oauth_client) => { + oauth_client + .set(resolved_oauth_client.clone()) + .map_err(|_| { + anyhow!( + "OAuth client resolved concurrently for MCP server `{server_name}`" + ) + })?; + let oauth_persistor = resolved_oauth_client.persistor(); + let transport = StreamableHttpClientTransport::with_client( + resolved_oauth_client, + StreamableHttpClientTransportConfig::with_uri(url.clone()), + ); Ok(PendingTransport::StreamableHttpWithOAuth { transport, oauth_persistor, @@ -970,19 +990,7 @@ impl RmcpClient { .await .map_err(|source| anyhow::Error::from(HandshakeError { source })), }; - let service = match service_result { - Ok(service) => service, - Err(error) => { - if let Some(runtime) = oauth_persistor.as_ref() - && let Err(persist_error) = runtime.persist_if_needed().await - { - warn!( - "failed to persist OAuth tokens after failed initialize: {persist_error}" - ); - } - return Err(error); - } - }; + let service = service_result?; Ok((Arc::new(service), oauth_persistor)) } @@ -997,38 +1005,94 @@ impl RmcpClient { F: Fn(Arc>) -> Fut, Fut: std::future::Future>, { - let service = self.service().await?; - match Self::run_service_operation_with_transient_retries( + let deadline = timeout.map(|duration| Instant::now() + duration); + // Keep the OAuth persistor paired with the service that performs this operation. Session + // recovery can replace both while the request is in flight; rereading only the persistor + // after a 401 could refresh credentials owned by a different transport lifecycle. + let (service, oauth_persistor) = self.service_and_oauth_persistor().await?; + let mut result = Self::run_service_operation_with_transient_retries( Arc::clone(&service), label, timeout, + deadline, self.elicitation_pause_state.clone(), &operation, ) - .await + .await; + + if let Some(rejected_access_token) = result + .as_ref() + .err() + .and_then(Self::rejected_access_token_from_operation_error) + && let Some(oauth_persistor) = oauth_persistor { - Ok(result) => Ok(result), - Err(error) if Self::is_session_expired_404(&error) => { - self.reinitialize_after_session_expiry(&service).await?; - let recovered_service = self.service().await?; - Self::run_service_operation_with_transient_retries( - recovered_service, - label, - timeout, - self.elicitation_pause_state.clone(), - &operation, - ) - .await - .map_err(Into::into) + // Public request/notification recovery stays here rather than in the transport + // wrapper because this layer owns the caller deadline. RMCP can continue processing a + // queued transport message after the caller times out; retrying it inside the wrapper + // could therefore replay a timed-out tool call. The refresh transaction itself is + // independently owned and completes to its bounded provider timeout if this caller is + // canceled. + let remaining = remaining_operation_timeout(label, timeout, deadline)?; + let refresh = oauth_persistor.refresh_after_unauthorized(rejected_access_token); + let refresh_result = match remaining { + Some(remaining) => match time::timeout(remaining, refresh).await { + Ok(result) => result, + Err(_) => { + // `refresh_after_unauthorized` spawns the credential transaction before it + // waits. Dropping this caller wait therefore detaches the JoinHandle while + // the transaction retains the credential lock and continues through its + // own provider/persistence bounds. The public operation still honors the + // timeout it advertised and does not replay the rejected request later. + return Err(ClientOperationError::Timeout { + label: label.to_string(), + duration: timeout.unwrap_or(remaining), + } + .into()); + } + }, + None => refresh.await, + }; + if let Err(error) = refresh_result { + if let Err(timeout_error) = remaining_operation_timeout(label, timeout, deadline) { + return Err(timeout_error.into()); + } + return Err(error); } - Err(error) => Err(error.into()), + result = Self::run_service_operation_with_transient_retries( + Arc::clone(&service), + label, + timeout, + deadline, + self.elicitation_pause_state.clone(), + &operation, + ) + .await; + } + + if result.as_ref().is_err_and(Self::is_session_expired_404) { + // Session recovery remains one-shot and runs after the optional OAuth retry, so a 401 + // followed by the old session's 404 still reconstructs the transport before retrying. + self.reinitialize_after_session_expiry(&service).await?; + let recovered_service = self.service().await?; + result = Self::run_service_operation_with_transient_retries( + recovered_service, + label, + timeout, + deadline, + self.elicitation_pause_state.clone(), + &operation, + ) + .await; } + + result.map_err(Into::into) } async fn run_service_operation_with_transient_retries( service: Arc>, label: &str, timeout: Option, + retry_deadline: Option, pause_state: ElicitationPauseState, operation: &F, ) -> std::result::Result @@ -1036,7 +1100,6 @@ impl RmcpClient { F: Fn(Arc>) -> Fut, Fut: std::future::Future>, { - let retry_deadline = timeout.map(|duration| Instant::now() + duration); for (attempt, retry_delay_ms) in STREAMABLE_HTTP_RETRY_DELAYS_MS .iter() .copied() @@ -1142,6 +1205,34 @@ impl RmcpClient { }) } + fn rejected_access_token_from_operation_error( + error: &ClientOperationError, + ) -> Option { + let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) = + error + else { + return None; + }; + + error + .error + .downcast_ref::>() + .and_then(Self::rejected_access_token) + } + + pub(super) fn rejected_access_token( + error: &StreamableHttpError, + ) -> Option { + match error { + StreamableHttpError::Client( + StreamableHttpClientAdapterError::AccessTokenRejected { + rejected_access_token, + }, + ) => Some(rejected_access_token.clone()), + _ => None, + } + } + async fn reinitialize_after_session_expiry( &self, failed_service: &Arc>, @@ -1179,7 +1270,7 @@ impl RmcpClient { .timeout .map(|duration| Instant::now() + duration); let (service, oauth_persistor) = self - .connect_pending_transport_with_initialize_retries( + .connect_pending_transport_with_oauth_recovery( pending_transport, initialize_context.client_service, initialize_context.timeout, @@ -1198,27 +1289,18 @@ impl RmcpClient { }; } - if let Some(runtime) = oauth_persistor - && let Err(error) = runtime.persist_if_needed().await - { - warn!("failed to persist OAuth tokens after session recovery: {error}"); - } - Ok(()) } } -async fn create_oauth_transport_and_runtime( +async fn create_oauth_transport_client( server_name: &str, url: &str, initial_tokens: StoredOAuthTokens, credential_store: ResolvedOAuthCredentialStore, default_headers: HeaderMap, http_client: Arc, -) -> Result<( - StreamableHttpClientTransport>, - OAuthPersistor, -)> { +) -> Result { let oauth_http_client = Arc::new(OAuthHttpClientAdapter::new( http_client.clone(), default_headers.clone(), @@ -1229,7 +1311,7 @@ async fn create_oauth_transport_and_runtime( oauth_state .set_credentials( &initial_tokens.client_id, - initial_tokens.token_response.0.clone(), + request_oauth_token_response(&initial_tokens), ) .await?; @@ -1247,11 +1329,6 @@ async fn create_oauth_transport_and_runtime( ); let auth_manager = auth_client.auth_manager.clone(); - let transport = StreamableHttpClientTransport::with_client( - auth_client, - StreamableHttpClientTransportConfig::with_uri(url.to_string()), - ); - let runtime = OAuthPersistor::new( server_name.to_string(), url.to_string(), @@ -1260,7 +1337,7 @@ async fn create_oauth_transport_and_runtime( Some(initial_tokens), ); - Ok((transport, runtime)) + Ok(OAuthTransportClient::new(auth_client, runtime)) } #[cfg(test)] diff --git a/codex-rs/rmcp-client/src/streamable_http_retry.rs b/codex-rs/rmcp-client/src/streamable_http_retry.rs index 2794489b2a76..540209e55de0 100644 --- a/codex-rs/rmcp-client/src/streamable_http_retry.rs +++ b/codex-rs/rmcp-client/src/streamable_http_retry.rs @@ -5,6 +5,7 @@ use std::time::Instant; use anyhow::Result; use anyhow::anyhow; use codex_exec_server::ExecServerError; +use oauth2::AccessToken; use reqwest::StatusCode; use rmcp::service::RoleClient; use rmcp::service::RunningService; @@ -22,8 +23,13 @@ use super::RmcpClient; const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603; pub(super) const STREAMABLE_HTTP_RETRY_DELAYS_MS: [u64; 2] = [250, 1_000]; +#[derive(Default)] +struct InitializeAttemptContext { + oauth_persistor: Option, +} + impl RmcpClient { - pub(super) async fn connect_pending_transport_with_initialize_retries( + pub(super) async fn connect_pending_transport_with_oauth_recovery( &self, initial_transport: PendingTransport, client_service: ElicitationClientService, @@ -32,6 +38,75 @@ impl RmcpClient { ) -> Result<( Arc>, Option, + )> { + let mut attempt_context = InitializeAttemptContext::default(); + match self + .connect_pending_transport_with_initialize_retries( + initial_transport, + client_service.clone(), + timeout, + initialize_deadline, + &mut attempt_context, + ) + .await + { + Ok(result) => Ok(result), + Err(error) => { + let Some(rejected_access_token) = + Self::rejected_access_token_from_initialize_error(&error) + else { + return Err(error); + }; + let Some(oauth_persistor) = attempt_context.oauth_persistor else { + return Err(error); + }; + // Initialization gets one OAuth refresh and one reconstructed transport. Reusing + // this wrapper for the retry would turn persistent 401s into a refresh loop. The + // startup deadline gates whether recovery starts and bounds transport setup plus + // the retry handshake, but the refresh transaction has its own bounds and is + // deliberately excluded from the startup budget. + remaining_initialize_timeout(timeout, *initialize_deadline)?; + let refresh_started_at = Instant::now(); + let refresh_result = oauth_persistor + .refresh_after_unauthorized(rejected_access_token) + .await; + if let Some(deadline) = initialize_deadline.as_mut() { + *deadline += refresh_started_at.elapsed(); + } + refresh_result?; + let remaining = remaining_initialize_timeout(timeout, *initialize_deadline)?; + let transport = match remaining { + Some(remaining) => time::timeout( + remaining, + Self::create_pending_transport(&self.transport_recipe), + ) + .await + .map_err(|_| initialize_timeout_error(timeout, remaining))??, + None => Self::create_pending_transport(&self.transport_recipe).await?, + }; + let mut retry_context = InitializeAttemptContext::default(); + self.connect_pending_transport_with_initialize_retries( + transport, + client_service, + timeout, + initialize_deadline, + &mut retry_context, + ) + .await + } + } + } + + async fn connect_pending_transport_with_initialize_retries( + &self, + initial_transport: PendingTransport, + client_service: ElicitationClientService, + timeout: Option, + initialize_deadline: &mut Option, + attempt_context: &mut InitializeAttemptContext, + ) -> Result<( + Arc>, + Option, )> { let should_retry = match &initial_transport { PendingTransport::InProcess { .. } | PendingTransport::Stdio { .. } => false, @@ -62,6 +137,17 @@ impl RmcpClient { } } }; + // Keep the persistor paired with the transport attempt that returned 401. Rebuilt + // transports reuse the recipe's lifecycle-pinned credential source, and this pairing + // also keeps the authorization manager and snapshot aligned with the failed attempt. + attempt_context.oauth_persistor = match &transport { + PendingTransport::StreamableHttpWithOAuth { + oauth_persistor, .. + } => Some(oauth_persistor.clone()), + PendingTransport::InProcess { .. } + | PendingTransport::Stdio { .. } + | PendingTransport::StreamableHttp { .. } => None, + }; match Self::connect_pending_transport( transport, client_service.clone(), @@ -108,6 +194,33 @@ impl RmcpClient { }) } + fn rejected_access_token_from_initialize_error(error: &anyhow::Error) -> Option { + error.chain().find_map(|source| { + source + .downcast_ref::() + .and_then(|error| { + Self::rejected_access_token_from_client_initialize_error(&error.source) + }) + .or_else(|| { + source + .downcast_ref::() + .and_then(Self::rejected_access_token_from_client_initialize_error) + }) + }) + } + + fn rejected_access_token_from_client_initialize_error( + error: &rmcp::service::ClientInitializeError, + ) -> Option { + match error { + rmcp::service::ClientInitializeError::TransportError { error, .. } => error + .error + .downcast_ref::>() + .and_then(Self::rejected_access_token), + _ => None, + } + } + fn is_retryable_client_initialize_error(error: &rmcp::service::ClientInitializeError) -> bool { match error { rmcp::service::ClientInitializeError::TransportError { error, context } @@ -158,6 +271,10 @@ impl RmcpClient { | StreamableHttpError::ServerDoesNotSupportSse | StreamableHttpError::Deserialize(_) | StreamableHttpError::Client(StreamableHttpClientAdapterError::SessionExpired404) + | StreamableHttpError::Client( + StreamableHttpClientAdapterError::AccessTokenRejected { .. }, + ) + | StreamableHttpError::Client(StreamableHttpClientAdapterError::OAuth(_)) | StreamableHttpError::Client(StreamableHttpClientAdapterError::Header(_)) => false, _ => false, } diff --git a/codex-rs/rmcp-client/tests/streamable_http_oauth_internal.rs b/codex-rs/rmcp-client/tests/streamable_http_oauth_internal.rs new file mode 100644 index 000000000000..dd17949567a1 --- /dev/null +++ b/codex-rs/rmcp-client/tests/streamable_http_oauth_internal.rs @@ -0,0 +1,219 @@ +mod streamable_http_test_support; + +use std::time::Duration; + +use codex_config::types::AuthKeyringBackendKind; +use codex_config::types::OAuthCredentialsStoreMode; +use codex_exec_server::Environment; +use codex_rmcp_client::RmcpClient; +use codex_rmcp_client::StoredOAuthTokens; +use codex_rmcp_client::WrappedOAuthTokenResponse; +use codex_rmcp_client::save_oauth_tokens; +use oauth2::AccessToken; +use oauth2::RefreshToken; +use oauth2::basic::BasicTokenType; +use rmcp::transport::auth::OAuthTokenResponse; +use rmcp::transport::auth::VendorExtraTokenFields; +use serde_json::Value; +use serde_json::json; +use tempfile::TempDir; +use tokio::process::Command; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::Request; +use wiremock::ResponseTemplate; +use wiremock::matchers::body_string_contains; +use wiremock::matchers::header; +use wiremock::matchers::method; +use wiremock::matchers::path; + +use streamable_http_test_support::initialize_client; + +const SERVER_NAME: &str = "test-streamable-http-oauth-internal"; +const SERVER_URL_ENV: &str = "MCP_TEST_OAUTH_INTERNAL_SERVER_URL"; +const ACCESS_TOKEN_A: &str = "internal-access-a"; +const REFRESH_TOKEN_A: &str = "internal-refresh-a"; +const ACCESS_TOKEN_B: &str = "internal-access-b"; +const REFRESH_TOKEN_B: &str = "internal-refresh-b"; +const ACCESS_TOKEN_C: &str = "internal-access-c"; +const REFRESH_TOKEN_C: &str = "internal-refresh-c"; + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn rmcp_owned_get_and_delete_receive_oauth_recovery() -> anyhow::Result<()> { + let server = MockServer::start().await; + mount_oauth_metadata(&server).await; + mount_refresh(&server, REFRESH_TOKEN_A, ACCESS_TOKEN_B, REFRESH_TOKEN_B).await; + mount_refresh(&server, REFRESH_TOKEN_B, ACCESS_TOKEN_C, REFRESH_TOKEN_C).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header("authorization", format!("Bearer {ACCESS_TOKEN_A}"))) + .respond_with(|request: &Request| { + let body: Value = request.body_json().expect("valid JSON-RPC request"); + match body.get("method").and_then(Value::as_str) { + Some("initialize") => initialize_response(&body), + Some("notifications/initialized") => ResponseTemplate::new(202), + method => ResponseTemplate::new(400) + .set_body_string(format!("unexpected JSON-RPC method: {method:?}")), + } + }) + .expect(2) + .mount(&server) + .await; + Mock::given(method("GET")) + .and(path("/mcp")) + .and(header("authorization", format!("Bearer {ACCESS_TOKEN_A}"))) + .respond_with(ResponseTemplate::new(401)) + .expect(1) + .mount(&server) + .await; + Mock::given(method("GET")) + .and(path("/mcp")) + .and(header("authorization", format!("Bearer {ACCESS_TOKEN_B}"))) + // A 405 tells RMCP that the optional common SSE stream is unsupported. Reaching this + // response proves that the wrapper retried the RMCP-owned GET with B after refreshing A. + .respond_with(ResponseTemplate::new(405)) + .expect(1) + .mount(&server) + .await; + Mock::given(method("DELETE")) + .and(path("/mcp")) + .and(header("authorization", format!("Bearer {ACCESS_TOKEN_B}"))) + .respond_with(ResponseTemplate::new(401)) + .expect(1) + .mount(&server) + .await; + Mock::given(method("DELETE")) + .and(path("/mcp")) + .and(header("authorization", format!("Bearer {ACCESS_TOKEN_C}"))) + .respond_with(ResponseTemplate::new(204)) + .expect(1) + .mount(&server) + .await; + + run_child( + "oauth_internal_get_delete_child", + &format!("{}/mcp", server.uri()), + ) + .await?; + server.verify().await; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[ignore = "spawned by rmcp_owned_get_and_delete_receive_oauth_recovery"] +async fn oauth_internal_get_delete_child() -> anyhow::Result<()> { + let client = create_oauth_client().await?; + initialize_client(&client).await?; + tokio::time::sleep(Duration::from_millis(/*millis*/ 500)).await; + client.shutdown().await; + tokio::time::sleep(Duration::from_millis(/*millis*/ 750)).await; + Ok(()) +} + +async fn create_oauth_client() -> anyhow::Result { + let server_url = std::env::var(SERVER_URL_ENV)?; + save_initial_tokens(&server_url)?; + RmcpClient::new_streamable_http_client( + SERVER_NAME, + &server_url, + /*bearer_token*/ None, + /*http_headers*/ None, + /*env_http_headers*/ None, + OAuthCredentialsStoreMode::File, + AuthKeyringBackendKind::default(), + Environment::default_for_tests().get_http_client(), + /*auth_provider*/ None, + ) + .await +} + +fn save_initial_tokens(server_url: &str) -> anyhow::Result<()> { + let mut response = OAuthTokenResponse::new( + AccessToken::new(ACCESS_TOKEN_A.to_string()), + BasicTokenType::Bearer, + VendorExtraTokenFields::default(), + ); + response.set_refresh_token(Some(RefreshToken::new(REFRESH_TOKEN_A.to_string()))); + response.set_expires_in(None); + save_oauth_tokens( + SERVER_NAME, + &StoredOAuthTokens { + server_name: SERVER_NAME.to_string(), + url: server_url.to_string(), + client_id: "test-client-id".to_string(), + token_response: WrappedOAuthTokenResponse(response), + expires_at: None, + }, + OAuthCredentialsStoreMode::File, + AuthKeyringBackendKind::default(), + ) +} + +async fn mount_oauth_metadata(server: &MockServer) { + Mock::given(method("GET")) + .and(path("/.well-known/oauth-authorization-server/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "authorization_endpoint": format!("{}/oauth/authorize", server.uri()), + "token_endpoint": format!("{}/oauth/token", server.uri()), + "scopes_supported": ["scope-a"], + }))) + .mount(server) + .await; +} + +async fn mount_refresh( + server: &MockServer, + request_refresh_token: &str, + response_access_token: &str, + response_refresh_token: &str, +) { + Mock::given(method("POST")) + .and(path("/oauth/token")) + .and(body_string_contains("grant_type=refresh_token")) + .and(body_string_contains(format!( + "refresh_token={request_refresh_token}" + ))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "access_token": response_access_token, + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": response_refresh_token, + "scope": "scope-a", + }))) + .expect(1) + .mount(server) + .await; +} + +fn initialize_response(body: &Value) -> ResponseTemplate { + ResponseTemplate::new(200) + .insert_header("mcp-session-id", "oauth-internal-session") + .set_body_json(json!({ + "jsonrpc": "2.0", + "id": body.get("id").cloned().unwrap_or(Value::Null), + "result": { + "protocolVersion": body + .pointer("/params/protocolVersion") + .cloned() + .unwrap_or_else(|| json!("2025-06-18")), + "capabilities": {}, + "serverInfo": { + "name": "oauth-internal-test", + "version": "0.0.0-test" + } + } + })) +} + +async fn run_child(test_name: &str, server_url: &str) -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + let status = Command::new(std::env::current_exe()?) + .args([test_name, "--exact", "--ignored", "--nocapture"]) + .env("CODEX_HOME", codex_home.path()) + .env(SERVER_URL_ENV, server_url) + .status() + .await?; + anyhow::ensure!(status.success(), "OAuth internal child failed: {status}"); + Ok(()) +} diff --git a/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs b/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs index 32f90bc41e8e..c0d582736aa1 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs @@ -1,6 +1,7 @@ mod streamable_http_test_support; use std::time::Duration; +use std::time::Instant; use std::time::SystemTime; use std::time::UNIX_EPOCH; @@ -40,6 +41,9 @@ const SERVER_NAME: &str = "test-streamable-http-oauth-startup"; const EXPIRED_ACCESS_TOKEN: &str = "expired-access-token"; const REFRESH_TOKEN: &str = "valid-refresh-token"; const REFRESHED_ACCESS_TOKEN: &str = "refreshed-access-token"; +const ROTATED_REFRESH_TOKEN: &str = "rotated-refresh-token"; +const FINAL_ACCESS_TOKEN: &str = "final-access-token"; +const FINAL_REFRESH_TOKEN: &str = "final-refresh-token"; const CHILD_SERVER_URL_ENV: &str = "MCP_TEST_OAUTH_STARTUP_SERVER_URL"; const UNREFRESHABLE_SERVER_URL: &str = "https://unrefreshable.example/mcp"; const UNEXPIRED_SERVER_URL: &str = "https://unexpired.example/mcp"; @@ -123,6 +127,185 @@ async fn refreshes_expired_persisted_token_before_initialize() -> anyhow::Result Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn recovers_initialization_and_operation_401_once() -> anyhow::Result<()> { + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/.well-known/oauth-authorization-server/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "authorization_endpoint": format!("{}/oauth/authorize", server.uri()), + "token_endpoint": format!("{}/oauth/token", server.uri()), + "scopes_supported": ["scope-a"], + }))) + .expect(1) + .mount(&server) + .await; + mount_refresh( + &server, + REFRESH_TOKEN, + REFRESHED_ACCESS_TOKEN, + ROTATED_REFRESH_TOKEN, + ) + .await; + mount_refresh( + &server, + ROTATED_REFRESH_TOKEN, + FINAL_ACCESS_TOKEN, + FINAL_REFRESH_TOKEN, + ) + .await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header( + "authorization", + format!("Bearer {EXPIRED_ACCESS_TOKEN}"), + )) + .respond_with(ResponseTemplate::new(401)) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header( + "authorization", + format!("Bearer {REFRESHED_ACCESS_TOKEN}"), + )) + .respond_with(|request: &Request| { + let body: Value = request.body_json().expect("valid JSON-RPC request"); + match body.get("method").and_then(Value::as_str) { + Some("initialize") => initialize_response(&body), + Some("notifications/initialized") => ResponseTemplate::new(202), + Some("tools/list") => ResponseTemplate::new(401), + method => ResponseTemplate::new(400) + .set_body_string(format!("unexpected JSON-RPC method: {method:?}")), + } + }) + .expect(3) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header( + "authorization", + format!("Bearer {FINAL_ACCESS_TOKEN}"), + )) + .respond_with(|request: &Request| { + let body: Value = request.body_json().expect("valid JSON-RPC request"); + ResponseTemplate::new(200).set_body_json(json!({ + "jsonrpc": "2.0", + "id": body.get("id").cloned().unwrap_or(Value::Null), + "result": { "tools": [] }, + })) + }) + .expect(1) + .mount(&server) + .await; + + let codex_home = TempDir::new()?; + let status = Command::new(std::env::current_exe()?) + .args([ + "oauth_401_recovery_child", + "--exact", + "--ignored", + "--nocapture", + ]) + .env("CODEX_HOME", codex_home.path()) + .env(CHILD_SERVER_URL_ENV, format!("{}/mcp", server.uri())) + .status() + .await?; + assert!(status.success(), "OAuth recovery child failed: {status}"); + server.verify().await; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn operation_timeout_bounds_unauthorized_refresh_wait() -> anyhow::Result<()> { + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/.well-known/oauth-authorization-server/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "authorization_endpoint": format!("{}/oauth/authorize", server.uri()), + "token_endpoint": format!("{}/oauth/token", server.uri()), + "scopes_supported": ["scope-a"], + }))) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/oauth/token")) + .and(body_string_contains("grant_type=refresh_token")) + .and(body_string_contains(format!( + "refresh_token={REFRESH_TOKEN}" + ))) + .respond_with( + ResponseTemplate::new(200) + .set_delay(Duration::from_millis(/*millis*/ 500)) + .set_body_json(json!({ + "access_token": REFRESHED_ACCESS_TOKEN, + "token_type": "Bearer", + "expires_in": 7200, + "refresh_token": ROTATED_REFRESH_TOKEN, + "scope": "scope-a", + })), + ) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header( + "authorization", + format!("Bearer {EXPIRED_ACCESS_TOKEN}"), + )) + .respond_with(|request: &Request| { + let body: Value = request.body_json().expect("valid JSON-RPC request"); + match body.get("method").and_then(Value::as_str) { + Some("initialize") => initialize_response(&body), + Some("notifications/initialized") => ResponseTemplate::new(202), + Some("tools/list") => ResponseTemplate::new(401), + method => ResponseTemplate::new(400) + .set_body_string(format!("unexpected JSON-RPC method: {method:?}")), + } + }) + .expect(4) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header( + "authorization", + format!("Bearer {REFRESHED_ACCESS_TOKEN}"), + )) + .respond_with(|request: &Request| { + let body: Value = request.body_json().expect("valid JSON-RPC request"); + ResponseTemplate::new(200).set_body_json(json!({ + "jsonrpc": "2.0", + "id": body.get("id").cloned().unwrap_or(Value::Null), + "result": { "tools": [] }, + })) + }) + .expect(1) + .mount(&server) + .await; + + let codex_home = TempDir::new()?; + let status = Command::new(std::env::current_exe()?) + .args([ + "oauth_401_timeout_child", + "--exact", + "--ignored", + "--nocapture", + ]) + .env("CODEX_HOME", codex_home.path()) + .env(CHILD_SERVER_URL_ENV, format!("{}/mcp", server.uri())) + .status() + .await?; + assert!(status.success(), "OAuth timeout child failed: {status}"); + server.verify().await; + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn reports_auth_status_for_persisted_credentials() -> anyhow::Result<()> { let codex_home = TempDir::new()?; @@ -236,7 +419,9 @@ async fn persisted_credentials_auth_status_child() -> anyhow::Result<()> { url: UNEXPIRED_SERVER_URL.to_string(), client_id: "test-client-id".to_string(), token_response: WrappedOAuthTokenResponse(response), - expires_at: Some(now.saturating_add(/*rhs*/ 60_000)), + // Keep this outside the 60-second proactive refresh guard band. The test is checking a + // healthy persisted access token, not the boundary where a refresh becomes necessary. + expires_at: Some(now.saturating_add(/*rhs*/ 120_000)), }; save_oauth_tokens( SERVER_NAME, @@ -376,3 +561,129 @@ async fn expired_unrefreshable_startup_child() -> anyhow::Result<()> { assert!(is_authentication_required_error(&error)); Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[ignore = "spawned by recovers_initialization_and_operation_401_once"] +async fn oauth_401_recovery_child() -> anyhow::Result<()> { + let server_url = std::env::var(CHILD_SERVER_URL_ENV)?; + let client = refreshable_oauth_client(&server_url).await?; + initialize_client(&client).await?; + let tools = client + .list_tools(/*params*/ None, Some(Duration::from_secs(/*secs*/ 5))) + .await?; + assert!(tools.tools.is_empty()); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[ignore = "spawned by operation_timeout_bounds_unauthorized_refresh_wait"] +async fn oauth_401_timeout_child() -> anyhow::Result<()> { + let server_url = std::env::var(CHILD_SERVER_URL_ENV)?; + let client = refreshable_oauth_client(&server_url).await?; + initialize_client(&client).await?; + + let started_at = Instant::now(); + let error = client + .list_tools( + /*params*/ None, + Some(Duration::from_millis(/*millis*/ 50)), + ) + .await + .expect_err("operation deadline should expire before the delayed refresh"); + assert!( + error.to_string().contains("timed out awaiting tools/list"), + "unexpected operation error: {error:#}" + ); + assert!( + started_at.elapsed() < Duration::from_millis(/*millis*/ 400), + "operation waited for the OAuth provider instead of its own deadline" + ); + + // The caller stopped waiting, but the owned refresh transaction must finish and update the + // shared manager. A later operation should then use the refreshed token without another + // provider request. + tokio::time::sleep(Duration::from_millis(/*millis*/ 600)).await; + let tools = client + .list_tools(/*params*/ None, Some(Duration::from_secs(/*secs*/ 5))) + .await?; + assert!(tools.tools.is_empty()); + Ok(()) +} + +async fn refreshable_oauth_client(server_url: &str) -> anyhow::Result { + let mut response = OAuthTokenResponse::new( + AccessToken::new(EXPIRED_ACCESS_TOKEN.to_string()), + BasicTokenType::Bearer, + VendorExtraTokenFields::default(), + ); + response.set_refresh_token(Some(RefreshToken::new(REFRESH_TOKEN.to_string()))); + response.set_expires_in(None); + save_oauth_tokens( + SERVER_NAME, + &StoredOAuthTokens { + server_name: SERVER_NAME.to_string(), + url: server_url.to_string(), + client_id: "test-client-id".to_string(), + token_response: WrappedOAuthTokenResponse(response), + expires_at: None, + }, + OAuthCredentialsStoreMode::File, + AuthKeyringBackendKind::default(), + )?; + + let client = RmcpClient::new_streamable_http_client( + SERVER_NAME, + server_url, + /*bearer_token*/ None, + /*http_headers*/ None, + /*env_http_headers*/ None, + OAuthCredentialsStoreMode::File, + AuthKeyringBackendKind::default(), + Environment::default_for_tests().get_http_client(), + /*auth_provider*/ None, + ) + .await?; + Ok(client) +} + +async fn mount_refresh( + server: &MockServer, + request_refresh_token: &str, + response_access_token: &str, + response_refresh_token: &str, +) { + Mock::given(method("POST")) + .and(path("/oauth/token")) + .and(body_string_contains("grant_type=refresh_token")) + .and(body_string_contains(format!( + "refresh_token={request_refresh_token}" + ))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "access_token": response_access_token, + "token_type": "Bearer", + "expires_in": 7200, + "refresh_token": response_refresh_token, + "scope": "scope-a", + }))) + .expect(1) + .mount(server) + .await; +} + +fn initialize_response(body: &Value) -> ResponseTemplate { + ResponseTemplate::new(200).set_body_json(json!({ + "jsonrpc": "2.0", + "id": body.get("id").cloned().unwrap_or(Value::Null), + "result": { + "protocolVersion": body + .pointer("/params/protocolVersion") + .cloned() + .unwrap_or_else(|| json!("2025-06-18")), + "capabilities": {}, + "serverInfo": { + "name": "oauth-401-recovery-test", + "version": "0.0.0-test", + }, + }, + })) +} From 8d9e34b3572b37d38cb529d45cb5a059788ecb06 Mon Sep 17 00:00:00 2001 From: Steven Lee Date: Sat, 27 Jun 2026 18:14:06 +0000 Subject: [PATCH 2/4] Stabilize MCP OAuth timeout race test --- codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs b/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs index c0d582736aa1..ab52d91a7985 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs @@ -268,7 +268,11 @@ async fn operation_timeout_bounds_unauthorized_refresh_wait() -> anyhow::Result< .set_body_string(format!("unexpected JSON-RPC method: {method:?}")), } }) - .expect(4) + // The later operation can either observe the completed background refresh immediately + // (three requests with the old token) or race it once before adopting the refreshed + // credentials (four). The exact provider-refresh and refreshed-token expectations below + // still require both paths to converge after one rotation. + .expect(3..=4) .mount(&server) .await; Mock::given(method("POST")) From 73e4bbcd7cc2d88783541b3cfe7036d97f8f6037 Mon Sep 17 00:00:00 2001 From: Steven Lee Date: Sun, 28 Jun 2026 07:30:51 +0000 Subject: [PATCH 3/4] Require reauthentication after rejected OAuth retry --- codex-rs/rmcp-client/src/oauth_transport.rs | 38 ++++-- .../rmcp-client/src/oauth_transport_tests.rs | 18 +++ codex-rs/rmcp-client/src/rmcp_client.rs | 10 ++ .../rmcp-client/src/streamable_http_retry.rs | 30 +++-- .../tests/streamable_http_oauth_startup.rs | 120 +++++++++++++++++- 5 files changed, 194 insertions(+), 22 deletions(-) diff --git a/codex-rs/rmcp-client/src/oauth_transport.rs b/codex-rs/rmcp-client/src/oauth_transport.rs index b6e9556de888..92d7ca20e59e 100644 --- a/codex-rs/rmcp-client/src/oauth_transport.rs +++ b/codex-rs/rmcp-client/src/oauth_transport.rs @@ -20,6 +20,7 @@ use reqwest::header::HeaderValue; use rmcp::model::ClientJsonRpcMessage; use rmcp::model::JsonRpcMessage; use rmcp::transport::auth::AuthClient; +use rmcp::transport::auth::AuthError; use rmcp::transport::streamable_http_client::StreamableHttpClient; use rmcp::transport::streamable_http_client::StreamableHttpError; use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; @@ -126,9 +127,11 @@ impl StreamableHttpClient for OAuthTransportClient { .recover_after_unauthorized("post_message", rejected_access_token) .await? { - self.auth_client - .post_message(uri, message, session_id, auth_token, custom_headers) - .await + authorization_required_after_retry( + self.auth_client + .post_message(uri, message, session_id, auth_token, custom_headers) + .await, + ) } else { result } @@ -156,9 +159,11 @@ impl StreamableHttpClient for OAuthTransportClient { .recover_after_unauthorized("delete_session", rejected_access_token) .await? { - self.auth_client - .delete_session(uri, session_id, auth_token, custom_headers) - .await + authorization_required_after_retry( + self.auth_client + .delete_session(uri, session_id, auth_token, custom_headers) + .await, + ) } else { result } @@ -190,15 +195,30 @@ impl StreamableHttpClient for OAuthTransportClient { .recover_after_unauthorized("get_stream", rejected_access_token) .await? { - self.auth_client - .get_stream(uri, session_id, last_event_id, auth_token, custom_headers) - .await + authorization_required_after_retry( + self.auth_client + .get_stream(uri, session_id, last_event_id, auth_token, custom_headers) + .await, + ) } else { result } } } +fn authorization_required_after_retry(result: TransportResult) -> TransportResult { + match result { + // The first 401 carries the token that was actually rejected so concurrent recovery can + // distinguish A from a newer B. Once the single retry also rejects B, attribution is no + // longer useful: surface the existing reauthentication marker instead of leaking the + // adapter-only error past the Codex-owned recovery boundary. + Err(StreamableHttpError::Client( + StreamableHttpClientAdapterError::AccessTokenRejected { .. }, + )) => Err(StreamableHttpError::Auth(AuthError::AuthorizationRequired)), + result => result, + } +} + fn rejected_access_token( error: &StreamableHttpError, ) -> Option { diff --git a/codex-rs/rmcp-client/src/oauth_transport_tests.rs b/codex-rs/rmcp-client/src/oauth_transport_tests.rs index 31b9e759b70d..de403249536b 100644 --- a/codex-rs/rmcp-client/src/oauth_transport_tests.rs +++ b/codex-rs/rmcp-client/src/oauth_transport_tests.rs @@ -10,10 +10,12 @@ use oauth2::basic::BasicTokenType; use reqwest::header::HeaderMap; use rmcp::model::ClientJsonRpcMessage; use rmcp::transport::auth::AuthClient; +use rmcp::transport::auth::AuthError; use rmcp::transport::auth::OAuthState; use rmcp::transport::auth::OAuthTokenResponse; use rmcp::transport::auth::VendorExtraTokenFields; use rmcp::transport::streamable_http_client::StreamableHttpClient; +use rmcp::transport::streamable_http_client::StreamableHttpError; use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; use serde_json::json; use tempfile::TempDir; @@ -27,7 +29,9 @@ use wiremock::matchers::method; use wiremock::matchers::path; use super::OAuthTransportClient; +use super::authorization_required_after_retry; use crate::http_client_adapter::StreamableHttpClientAdapter; +use crate::http_client_adapter::StreamableHttpClientAdapterError; use crate::oauth::OAuthPersistor; use crate::oauth::ResolvedOAuthCredentialStore; use crate::oauth::StoredOAuthTokens; @@ -43,6 +47,20 @@ const REFRESH_TOKEN_A: &str = "response-refresh-a"; const ACCESS_TOKEN_B: &str = "response-access-b"; const REFRESH_TOKEN_B: &str = "response-refresh-b"; +#[test] +fn exhausted_transport_retry_requires_reauthentication() { + let result = authorization_required_after_retry::<()>(Err(StreamableHttpError::Client( + StreamableHttpClientAdapterError::AccessTokenRejected { + rejected_access_token: AccessToken::new(ACCESS_TOKEN_B.to_string()), + }, + ))); + + assert!(matches!( + result, + Err(StreamableHttpError::Auth(AuthError::AuthorizationRequired)) + )); +} + #[tokio::test] async fn server_response_post_receives_one_shot_oauth_recovery() -> anyhow::Result<()> { let server = MockServer::start().await; diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index bfd9274ab7c1..1006a5bd5d3d 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -1050,6 +1050,16 @@ impl RmcpClient { &operation, ) .await; + if result + .as_ref() + .err() + .and_then(Self::rejected_access_token_from_operation_error) + .is_some() + { + // The rejected token is needed only to attribute the first 401. A second 401 + // after the one allowed refresh means this lifecycle needs reauthentication. + return Err(AuthError::AuthorizationRequired.into()); + } } if result.as_ref().is_err_and(Self::is_session_expired_404) { diff --git a/codex-rs/rmcp-client/src/streamable_http_retry.rs b/codex-rs/rmcp-client/src/streamable_http_retry.rs index 7acd9e4269cf..9bff1d548ad4 100644 --- a/codex-rs/rmcp-client/src/streamable_http_retry.rs +++ b/codex-rs/rmcp-client/src/streamable_http_retry.rs @@ -9,6 +9,7 @@ use oauth2::AccessToken; use reqwest::StatusCode; use rmcp::service::RoleClient; use rmcp::service::RunningService; +use rmcp::transport::auth::AuthError; use rmcp::transport::streamable_http_client::StreamableHttpError; use tokio::time; use tracing::warn; @@ -85,14 +86,27 @@ impl RmcpClient { None => Self::create_pending_transport(&self.transport_recipe).await?, }; let mut retry_context = InitializeAttemptContext::default(); - self.connect_pending_transport_with_initialize_retries( - transport, - client_service, - timeout, - &mut initialize_deadline, - &mut retry_context, - ) - .await + let result = self + .connect_pending_transport_with_initialize_retries( + transport, + client_service, + timeout, + &mut initialize_deadline, + &mut retry_context, + ) + .await; + if result + .as_ref() + .err() + .and_then(Self::rejected_access_token_from_initialize_error) + .is_some() + { + // The first 401 identifies which access token failed. If the reconstructed + // transport still rejects the refreshed token, preserve Codex's established + // signal that the user must authenticate again. + return Err(AuthError::AuthorizationRequired.into()); + } + result } } } diff --git a/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs b/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs index ab52d91a7985..d8ea63914e07 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs @@ -44,6 +44,8 @@ const REFRESHED_ACCESS_TOKEN: &str = "refreshed-access-token"; const ROTATED_REFRESH_TOKEN: &str = "rotated-refresh-token"; const FINAL_ACCESS_TOKEN: &str = "final-access-token"; const FINAL_REFRESH_TOKEN: &str = "final-refresh-token"; +const REJECTED_RETRY_ACCESS_TOKEN: &str = "rejected-retry-access-token"; +const REJECTED_RETRY_REFRESH_TOKEN: &str = "rejected-retry-refresh-token"; const CHILD_SERVER_URL_ENV: &str = "MCP_TEST_OAUTH_STARTUP_SERVER_URL"; const UNREFRESHABLE_SERVER_URL: &str = "https://unrefreshable.example/mcp"; const UNEXPIRED_SERVER_URL: &str = "https://unexpired.example/mcp"; @@ -154,6 +156,13 @@ async fn recovers_initialization_and_operation_401_once() -> anyhow::Result<()> FINAL_REFRESH_TOKEN, ) .await; + mount_refresh( + &server, + FINAL_REFRESH_TOKEN, + REJECTED_RETRY_ACCESS_TOKEN, + REJECTED_RETRY_REFRESH_TOKEN, + ) + .await; Mock::given(method("POST")) .and(path("/mcp")) @@ -192,12 +201,30 @@ async fn recovers_initialization_and_operation_401_once() -> anyhow::Result<()> )) .respond_with(|request: &Request| { let body: Value = request.body_json().expect("valid JSON-RPC request"); - ResponseTemplate::new(200).set_body_json(json!({ - "jsonrpc": "2.0", - "id": body.get("id").cloned().unwrap_or(Value::Null), - "result": { "tools": [] }, - })) + match body.get("method").and_then(Value::as_str) { + Some("tools/list") => ResponseTemplate::new(200).set_body_json(json!({ + "jsonrpc": "2.0", + "id": body.get("id").cloned().unwrap_or(Value::Null), + "result": { "tools": [] }, + })), + Some("resources/list") => ResponseTemplate::new(401) + .insert_header("www-authenticate", "Bearer realm=\"mcp\""), + method => ResponseTemplate::new(400) + .set_body_string(format!("unexpected JSON-RPC method: {method:?}")), + } }) + .expect(2) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header( + "authorization", + format!("Bearer {REJECTED_RETRY_ACCESS_TOKEN}"), + )) + .respond_with( + ResponseTemplate::new(401).insert_header("www-authenticate", "Bearer realm=\"mcp\""), + ) .expect(1) .mount(&server) .await; @@ -219,6 +246,71 @@ async fn recovers_initialization_and_operation_401_once() -> anyhow::Result<()> Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn rejected_initialize_retry_requires_reauthentication() -> anyhow::Result<()> { + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/.well-known/oauth-authorization-server/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "authorization_endpoint": format!("{}/oauth/authorize", server.uri()), + "token_endpoint": format!("{}/oauth/token", server.uri()), + "scopes_supported": ["scope-a"], + }))) + .expect(1) + .mount(&server) + .await; + mount_refresh( + &server, + REFRESH_TOKEN, + REFRESHED_ACCESS_TOKEN, + ROTATED_REFRESH_TOKEN, + ) + .await; + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header( + "authorization", + format!("Bearer {EXPIRED_ACCESS_TOKEN}"), + )) + .respond_with( + ResponseTemplate::new(401).insert_header("www-authenticate", "Bearer realm=\"mcp\""), + ) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header( + "authorization", + format!("Bearer {REFRESHED_ACCESS_TOKEN}"), + )) + .respond_with( + ResponseTemplate::new(401).insert_header("www-authenticate", "Bearer realm=\"mcp\""), + ) + .expect(1) + .mount(&server) + .await; + + let codex_home = TempDir::new()?; + let status = Command::new(std::env::current_exe()?) + .args([ + "oauth_rejected_initialize_retry_child", + "--exact", + "--ignored", + "--nocapture", + ]) + .env("CODEX_HOME", codex_home.path()) + .env(CHILD_SERVER_URL_ENV, format!("{}/mcp", server.uri())) + .status() + .await?; + assert!( + status.success(), + "OAuth rejected-retry child failed: {status}" + ); + server.verify().await; + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn operation_timeout_bounds_unauthorized_refresh_wait() -> anyhow::Result<()> { let server = MockServer::start().await; @@ -576,6 +668,24 @@ async fn oauth_401_recovery_child() -> anyhow::Result<()> { .list_tools(/*params*/ None, Some(Duration::from_secs(/*secs*/ 5))) .await?; assert!(tools.tools.is_empty()); + + let error = client + .list_resources(/*params*/ None, Some(Duration::from_secs(/*secs*/ 5))) + .await + .expect_err("a rejected one-shot OAuth retry should require reauthentication"); + assert!(is_authentication_required_error(&error)); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[ignore = "spawned by rejected_initialize_retry_requires_reauthentication"] +async fn oauth_rejected_initialize_retry_child() -> anyhow::Result<()> { + let server_url = std::env::var(CHILD_SERVER_URL_ENV)?; + let client = refreshable_oauth_client(&server_url).await?; + let error = initialize_client(&client) + .await + .expect_err("a rejected initialize retry should require reauthentication"); + assert!(is_authentication_required_error(&error)); Ok(()) } From 4d735e681a8a85b8d92725761376d450b6d79fee Mon Sep 17 00:00:00 2001 From: Steven Lee Date: Sun, 28 Jun 2026 08:01:53 +0000 Subject: [PATCH 4/4] Compose MCP transport recovery boundaries --- .../rmcp-client/src/http_client_adapter.rs | 17 ++- codex-rs/rmcp-client/src/oauth_transport.rs | 13 ++ .../rmcp-client/src/oauth_transport_tests.rs | 15 +- codex-rs/rmcp-client/src/rmcp_client.rs | 138 ++++++++---------- codex-rs/rmcp-client/src/startup_error.rs | 2 +- .../tests/streamable_http_oauth_startup.rs | 69 +++++---- .../tests/streamable_http_recovery.rs | 41 +++++- .../tests/streamable_http_test_support.rs | 2 + 8 files changed, 186 insertions(+), 111 deletions(-) diff --git a/codex-rs/rmcp-client/src/http_client_adapter.rs b/codex-rs/rmcp-client/src/http_client_adapter.rs index a7d934134812..eccce8c03dfd 100644 --- a/codex-rs/rmcp-client/src/http_client_adapter.rs +++ b/codex-rs/rmcp-client/src/http_client_adapter.rs @@ -56,6 +56,7 @@ pub(crate) struct StreamableHttpClientAdapter { http_client: Arc, default_headers: HeaderMap, auth_provider: Option, + attribute_rejected_access_token: bool, } #[derive(Debug, thiserror::Error)] @@ -82,8 +83,15 @@ impl StreamableHttpClientAdapter { http_client, default_headers, auth_provider, + attribute_rejected_access_token: false, } } + + /// Preserves the access token associated with a 401 for Codex-owned OAuth recovery. + pub(crate) fn with_rejected_token_attribution(mut self) -> Self { + self.attribute_rejected_access_token = true; + self + } } impl StreamableHttpClient for StreamableHttpClientAdapter { @@ -167,7 +175,8 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { StreamableHttpClientAdapterError::SessionExpired404, )); } - if response.status == StatusCode::UNAUTHORIZED.as_u16() + if self.attribute_rejected_access_token + && response.status == StatusCode::UNAUTHORIZED.as_u16() && let Some(error) = access_token_rejected(auth_token.as_deref()) { return Err(error); @@ -284,7 +293,8 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { if response.status == StatusCode::METHOD_NOT_ALLOWED.as_u16() { return Ok(()); } - if response.status == StatusCode::UNAUTHORIZED.as_u16() + if self.attribute_rejected_access_token + && response.status == StatusCode::UNAUTHORIZED.as_u16() && let Some(error) = access_token_rejected(auth_token.as_deref()) { return Err(error); @@ -364,7 +374,8 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { StreamableHttpClientAdapterError::SessionExpired404, )); } - if response.status == StatusCode::UNAUTHORIZED.as_u16() + if self.attribute_rejected_access_token + && response.status == StatusCode::UNAUTHORIZED.as_u16() && let Some(error) = access_token_rejected(auth_token.as_deref()) { return Err(error); diff --git a/codex-rs/rmcp-client/src/oauth_transport.rs b/codex-rs/rmcp-client/src/oauth_transport.rs index 92d7ca20e59e..e5ea3fc2f1f5 100644 --- a/codex-rs/rmcp-client/src/oauth_transport.rs +++ b/codex-rs/rmcp-client/src/oauth_transport.rs @@ -233,6 +233,19 @@ fn rejected_access_token( fn oauth_transport_error( error: anyhow::Error, ) -> StreamableHttpError { + if let Some(auth_error) = + error + .chain() + .find_map(|source| match source.downcast_ref::() { + Some(AuthError::AuthorizationRequired) => Some(AuthError::AuthorizationRequired), + Some(AuthError::TokenExpired) => Some(AuthError::TokenExpired), + _ => None, + }) + { + // Preserve RMCP's established reauthentication variants across Codex's transport policy + // boundary. Other OAuth failures retain their context-rich adapter error. + return StreamableHttpError::Auth(auth_error); + } StreamableHttpError::Client(StreamableHttpClientAdapterError::OAuth(error)) } diff --git a/codex-rs/rmcp-client/src/oauth_transport_tests.rs b/codex-rs/rmcp-client/src/oauth_transport_tests.rs index de403249536b..acbca36c10bc 100644 --- a/codex-rs/rmcp-client/src/oauth_transport_tests.rs +++ b/codex-rs/rmcp-client/src/oauth_transport_tests.rs @@ -30,6 +30,7 @@ use wiremock::matchers::path; use super::OAuthTransportClient; use super::authorization_required_after_retry; +use super::oauth_transport_error; use crate::http_client_adapter::StreamableHttpClientAdapter; use crate::http_client_adapter::StreamableHttpClientAdapterError; use crate::oauth::OAuthPersistor; @@ -61,6 +62,17 @@ fn exhausted_transport_retry_requires_reauthentication() { )); } +#[test] +fn oauth_transport_preserves_reauthentication_errors() { + let error = anyhow::Error::new(AuthError::AuthorizationRequired) + .context("refreshing rejected MCP access token"); + + assert!(matches!( + oauth_transport_error(error), + StreamableHttpError::Auth(AuthError::AuthorizationRequired) + )); +} + #[tokio::test] async fn server_response_post_receives_one_shot_oauth_recovery() -> anyhow::Result<()> { let server = MockServer::start().await; @@ -155,7 +167,8 @@ async fn server_response_post_child() -> anyhow::Result<()> { Arc::clone(&http_client), HeaderMap::new(), /*auth_provider*/ None, - ), + ) + .with_rejected_token_attribution(), manager, ); let persistor = OAuthPersistor::new( diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 1006a5bd5d3d..c189b26863ec 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -705,12 +705,6 @@ impl RmcpClient { Ok(response) } - async fn service(&self) -> Result>> { - self.service_and_oauth_persistor() - .await - .map(|(service, _oauth_persistor)| service) - } - async fn service_and_oauth_persistor( &self, ) -> Result<( @@ -992,56 +986,12 @@ impl RmcpClient { // Keep the OAuth persistor paired with the service that performs this operation. Session // recovery can replace both while the request is in flight; rereading only the persistor // after a 401 could refresh credentials owned by a different transport lifecycle. - let (service, oauth_persistor) = self.service_and_oauth_persistor().await?; - let mut result = Self::run_service_operation_with_transient_retries( - Arc::clone(&service), - label, - timeout, - deadline, - self.elicitation_pause_state.clone(), - &operation, - ) - .await; + let (mut service, mut oauth_persistor) = self.service_and_oauth_persistor().await?; + let mut oauth_recovered = false; + let mut session_recovered = false; - if let Some(rejected_access_token) = result - .as_ref() - .err() - .and_then(Self::rejected_access_token_from_operation_error) - && let Some(oauth_persistor) = oauth_persistor - { - // Public request/notification recovery stays here rather than in the transport - // wrapper because this layer owns the caller deadline. RMCP can continue processing a - // queued transport message after the caller times out; retrying it inside the wrapper - // could therefore replay a timed-out tool call. The refresh transaction itself is - // independently owned and completes to its bounded provider timeout if this caller is - // canceled. - let remaining = remaining_operation_timeout(label, timeout, deadline)?; - let refresh = oauth_persistor.refresh_after_unauthorized(rejected_access_token); - let refresh_result = match remaining { - Some(remaining) => match time::timeout(remaining, refresh).await { - Ok(result) => result, - Err(_) => { - // `refresh_after_unauthorized` spawns the credential transaction before it - // waits. Dropping this caller wait therefore detaches the JoinHandle while - // the transaction retains the credential lock and continues through its - // own provider/persistence bounds. The public operation still honors the - // timeout it advertised and does not replay the rejected request later. - return Err(ClientOperationError::Timeout { - label: label.to_string(), - duration: timeout.unwrap_or(remaining), - } - .into()); - } - }, - None => refresh.await, - }; - if let Err(error) = refresh_result { - if let Err(timeout_error) = remaining_operation_timeout(label, timeout, deadline) { - return Err(timeout_error.into()); - } - return Err(error); - } - result = Self::run_service_operation_with_transient_retries( + loop { + let result = Self::run_service_operation_with_transient_retries( Arc::clone(&service), label, timeout, @@ -1050,35 +1000,66 @@ impl RmcpClient { &operation, ) .await; - if result + + if let Some(rejected_access_token) = result .as_ref() .err() .and_then(Self::rejected_access_token_from_operation_error) - .is_some() { - // The rejected token is needed only to attribute the first 401. A second 401 - // after the one allowed refresh means this lifecycle needs reauthentication. - return Err(AuthError::AuthorizationRequired.into()); + if oauth_recovered { + // The rejected token is needed only to attribute the first 401. A second 401 + // after the one allowed refresh means this lifecycle needs reauthentication. + return Err(AuthError::AuthorizationRequired.into()); + } + let Some(oauth_persistor) = oauth_persistor.as_ref() else { + return result.map_err(Into::into); + }; + + // Public request/notification recovery stays here rather than in the transport + // wrapper because this layer owns the caller deadline. RMCP can continue + // processing a queued transport message after the caller times out; retrying it + // inside the wrapper could replay a timed-out tool call. + let remaining = remaining_operation_timeout(label, timeout, deadline)?; + let refresh = oauth_persistor.refresh_after_unauthorized(rejected_access_token); + let refresh_result = match remaining { + Some(remaining) => match time::timeout(remaining, refresh).await { + Ok(result) => result, + Err(_) => { + // The owned transaction keeps running after this caller stops waiting, + // but the rejected operation is not replayed after its deadline. + return Err(ClientOperationError::Timeout { + label: label.to_string(), + duration: timeout.unwrap_or(remaining), + } + .into()); + } + }, + None => refresh.await, + }; + if let Err(error) = refresh_result { + if let Err(timeout_error) = + remaining_operation_timeout(label, timeout, deadline) + { + return Err(timeout_error.into()); + } + return Err(error); + } + oauth_recovered = true; + continue; } - } - if result.as_ref().is_err_and(Self::is_session_expired_404) { - // Session recovery remains one-shot and runs after the optional OAuth retry, so a 401 - // followed by the old session's 404 still reconstructs the transport before retrying. - self.reinitialize_after_session_expiry(&service).await?; - let recovered_service = self.service().await?; - result = Self::run_service_operation_with_transient_retries( - recovered_service, - label, - timeout, - deadline, - self.elicitation_pause_state.clone(), - &operation, - ) - .await; - } + if !session_recovered && result.as_ref().is_err_and(Self::is_session_expired_404) { + // OAuth and session recovery are each one-shot, but either error may arrive first. + // Re-entering this loop lets 404 -> 401 compose just like the existing 401 -> 404 + // path without allowing either recovery to repeat indefinitely. + self.reinitialize_after_session_expiry(&service).await?; + (service, oauth_persistor) = self.service_and_oauth_persistor().await?; + session_recovered = true; + continue; + } - result.map_err(Into::into) + return result.map_err(Into::into); + } } async fn run_service_operation_with_transient_retries( @@ -1313,7 +1294,8 @@ async fn create_oauth_transport_client( }; let auth_client = AuthClient::new( - StreamableHttpClientAdapter::new(http_client, default_headers, /*auth_provider*/ None), + StreamableHttpClientAdapter::new(http_client, default_headers, /*auth_provider*/ None) + .with_rejected_token_attribution(), manager, ); let auth_manager = auth_client.auth_manager.clone(); diff --git a/codex-rs/rmcp-client/src/startup_error.rs b/codex-rs/rmcp-client/src/startup_error.rs index c74f637e3d26..d10c8cd0ed9c 100644 --- a/codex-rs/rmcp-client/src/startup_error.rs +++ b/codex-rs/rmcp-client/src/startup_error.rs @@ -34,7 +34,7 @@ fn client_initialize_error_requires_authentication(error: &ClientInitializeError error, StreamableHttpError::Auth(auth_error) if auth_error_requires_authentication(auth_error) - ) + ) || matches!(error, StreamableHttpError::AuthRequired(_)) }) } diff --git a/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs b/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs index d8ea63914e07..1e1b6db69071 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs @@ -1,5 +1,8 @@ mod streamable_http_test_support; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; use std::time::Duration; use std::time::Instant; use std::time::SystemTime; @@ -199,21 +202,31 @@ async fn recovers_initialization_and_operation_401_once() -> anyhow::Result<()> "authorization", format!("Bearer {FINAL_ACCESS_TOKEN}"), )) - .respond_with(|request: &Request| { - let body: Value = request.body_json().expect("valid JSON-RPC request"); - match body.get("method").and_then(Value::as_str) { - Some("tools/list") => ResponseTemplate::new(200).set_body_json(json!({ - "jsonrpc": "2.0", - "id": body.get("id").cloned().unwrap_or(Value::Null), - "result": { "tools": [] }, - })), - Some("resources/list") => ResponseTemplate::new(401) - .insert_header("www-authenticate", "Bearer realm=\"mcp\""), - method => ResponseTemplate::new(400) - .set_body_string(format!("unexpected JSON-RPC method: {method:?}")), + .respond_with({ + let resource_attempts = Arc::new(AtomicUsize::new(0)); + move |request: &Request| { + let body: Value = request.body_json().expect("valid JSON-RPC request"); + match body.get("method").and_then(Value::as_str) { + Some("tools/list") => ResponseTemplate::new(200).set_body_json(json!({ + "jsonrpc": "2.0", + "id": body.get("id").cloned().unwrap_or(Value::Null), + "result": { "tools": [] }, + })), + Some("resources/list") + if resource_attempts.fetch_add(1, Ordering::SeqCst) == 0 => + { + ResponseTemplate::new(404) + } + Some("resources/list") => ResponseTemplate::new(401) + .insert_header("www-authenticate", "Bearer realm=\"mcp\""), + Some("initialize") => initialize_response(&body), + Some("notifications/initialized") => ResponseTemplate::new(202), + method => ResponseTemplate::new(400) + .set_body_string(format!("unexpected JSON-RPC method: {method:?}")), + } } }) - .expect(2) + .expect(5) .mount(&server) .await; Mock::given(method("POST")) @@ -785,19 +798,21 @@ async fn mount_refresh( } fn initialize_response(body: &Value) -> ResponseTemplate { - ResponseTemplate::new(200).set_body_json(json!({ - "jsonrpc": "2.0", - "id": body.get("id").cloned().unwrap_or(Value::Null), - "result": { - "protocolVersion": body - .pointer("/params/protocolVersion") - .cloned() - .unwrap_or_else(|| json!("2025-06-18")), - "capabilities": {}, - "serverInfo": { - "name": "oauth-401-recovery-test", - "version": "0.0.0-test", + ResponseTemplate::new(200) + .insert_header("mcp-session-id", "oauth-recovery-session") + .set_body_json(json!({ + "jsonrpc": "2.0", + "id": body.get("id").cloned().unwrap_or(Value::Null), + "result": { + "protocolVersion": body + .pointer("/params/protocolVersion") + .cloned() + .unwrap_or_else(|| json!("2025-06-18")), + "capabilities": {}, + "serverInfo": { + "name": "oauth-401-recovery-test", + "version": "0.0.0-test", + }, }, - }, - })) + })) } diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index 29d2404ede9e..a53d0cb08109 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -11,6 +11,8 @@ use codex_exec_server::HttpClient; use codex_exec_server::HttpRequestParams; use codex_exec_server::HttpRequestResponse; use codex_exec_server::HttpResponseBodyStream; +use codex_rmcp_client::RmcpClient; +use codex_rmcp_client::is_authentication_required_error; use futures::FutureExt as _; use futures::future::BoxFuture; use pretty_assertions::assert_eq; @@ -124,7 +126,13 @@ async fn streamable_http_initialize_retries_remote_no_response_error() -> anyhow async fn streamable_http_initialize_retries_transient_http_status() -> anyhow::Result<()> { let (_server, base_url) = spawn_streamable_http_server().await?; - arm_initialize_post_failure(&base_url, /*status*/ 502, /*remaining*/ 1).await?; + arm_initialize_post_failure( + &base_url, + /*status*/ 502, + /*remaining*/ 1, + /*www_authenticate_headers*/ &[], + ) + .await?; let client = create_client(&base_url).await?; let result = call_echo_tool(&client, "after-status-retry").await?; @@ -301,6 +309,37 @@ async fn streamable_http_401_does_not_trigger_recovery() -> anyhow::Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn static_bearer_initialize_401_requires_authentication() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + arm_initialize_post_failure( + &base_url, + /*status*/ 401, + /*remaining*/ 1, + /*www_authenticate_headers*/ &[r#"Bearer realm="mcp""#], + ) + .await?; + + let client = RmcpClient::new_streamable_http_client( + "test-static-bearer-401", + &format!("{base_url}/mcp"), + Some("test-bearer".to_string()), + /*http_headers*/ None, + /*env_http_headers*/ None, + codex_config::types::OAuthCredentialsStoreMode::File, + codex_config::types::AuthKeyringBackendKind::default(), + Environment::default_for_tests().get_http_client(), + /*auth_provider*/ None, + ) + .await?; + let error = streamable_http_test_support::initialize_client(&client) + .await + .expect_err("a challenged static bearer token should require authentication"); + + assert!(is_authentication_required_error(&error)); + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn streamable_http_403_scope_challenge_returns_insufficient_scope() -> anyhow::Result<()> { let (_server, base_url) = spawn_streamable_http_server().await?; diff --git a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs index 87a52f55dff6..30d3039f5bb7 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs @@ -258,12 +258,14 @@ pub(crate) async fn arm_initialize_post_failure( base_url: &str, status: u16, remaining: usize, + www_authenticate_headers: &[&str], ) -> anyhow::Result<()> { let response = reqwest::Client::new() .post(format!("{base_url}{INITIALIZE_POST_FAILURE_CONTROL_PATH}")) .json(&json!({ "status": status, "remaining": remaining, + "www_authenticate_headers": www_authenticate_headers, })) .send() .await?;