diff --git a/codex-rs/config/src/types.rs b/codex-rs/config/src/types.rs index 5303fccd43b6..ccb647a7bf16 100644 --- a/codex-rs/config/src/types.rs +++ b/codex-rs/config/src/types.rs @@ -103,7 +103,10 @@ pub enum AuthCredentialsStoreMode { #[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "lowercase")] pub enum OAuthCredentialsStoreMode { - /// `Keyring` when available; otherwise, `File`. + /// Prefer `Keyring` and use `File` when keyring storage is unavailable. + /// Once an MCP client loads credentials from one store, that client keeps the resolved store + /// for its lifetime: refresh reads and writes surface store failures instead of switching to + /// the other store and risking replay of an older rotating refresh token. /// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access. #[default] Auto, diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index 92a45d1978b5..6c3e3a3ed6bd 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -2098,7 +2098,7 @@ "description": "Determine where Codex should store and read MCP credentials.", "oneOf": [ { - "description": "`Keyring` when available; otherwise, `File`. Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access.", + "description": "Prefer `Keyring` and use `File` when keyring storage is unavailable. Once an MCP client loads credentials from one store, that client keeps the resolved store for its lifetime: refresh reads and writes surface store failures instead of switching to the other store and risking replay of an older rotating refresh token. Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access.", "enum": [ "auto" ], @@ -5497,4 +5497,4 @@ }, "title": "ConfigToml", "type": "object" -} \ No newline at end of file +} diff --git a/codex-rs/rmcp-client/src/lib.rs b/codex-rs/rmcp-client/src/lib.rs index e1041fdafbbf..5aae4493ce34 100644 --- a/codex-rs/rmcp-client/src/lib.rs +++ b/codex-rs/rmcp-client/src/lib.rs @@ -23,7 +23,6 @@ pub use in_process_transport::InProcessTransportFactory; pub use oauth::StoredOAuthTokens; pub use oauth::WrappedOAuthTokenResponse; pub use oauth::delete_oauth_tokens; -pub(crate) use oauth::load_oauth_tokens; pub use oauth::save_oauth_tokens; pub use perform_oauth_login::OAuthProviderError; pub use perform_oauth_login::OauthLoginHandle; diff --git a/codex-rs/rmcp-client/src/oauth.rs b/codex-rs/rmcp-client/src/oauth.rs index d0005d20a6d0..7ca5b28fd28b 100644 --- a/codex-rs/rmcp-client/src/oauth.rs +++ b/codex-rs/rmcp-client/src/oauth.rs @@ -16,6 +16,10 @@ //! //! If the keyring is not available or fails, we fall back to CODEX_HOME/.credentials.json which is consistent with other coding CLI agents. +mod persistor; +mod refresh_lock; +mod resolved_store; + use anyhow::Context; use anyhow::Error; use anyhow::Result; @@ -51,11 +55,20 @@ use tracing::warn; use codex_keyring_store::DefaultKeyringStore; use codex_keyring_store::KeyringStore; -use rmcp::transport::auth::AuthorizationManager; -use tokio::sync::Mutex; - use codex_utils_home_dir::find_codex_home; +pub(crate) use self::persistor::OAuthPersistor; +pub(crate) use self::resolved_store::LoadedOAuthTokens; +pub(crate) use self::resolved_store::ResolvedOAuthCredentialStore; +#[cfg(test)] +use self::resolved_store::load_oauth_tokens_from_keyring_with_fallback_to_file; +pub(crate) use self::resolved_store::load_oauth_tokens_from_resolved_store; +pub(crate) use self::resolved_store::load_oauth_tokens_with_source; +#[cfg(test)] +use self::resolved_store::load_oauth_tokens_with_source_and_keyring_store; +#[cfg(test)] +use rmcp::transport::auth::AuthorizationManager; + const KEYRING_SERVICE: &str = "Codex MCP Credentials"; const MCP_OAUTH_SECRET_PREFIX: &str = "MCP_OAUTH"; const REFRESH_SKEW_MILLIS: u64 = 30_000; @@ -96,20 +109,10 @@ pub(crate) fn load_oauth_tokens( store_mode: OAuthCredentialsStoreMode, keyring_backend_kind: AuthKeyringBackendKind, ) -> Result> { - let keyring_store = DefaultKeyringStore; - match store_mode { - OAuthCredentialsStoreMode::Auto => load_oauth_tokens_from_keyring_with_fallback_to_file( - &keyring_store, - keyring_backend_kind, - server_name, - url, - ), - OAuthCredentialsStoreMode::File => load_oauth_tokens_from_file(server_name, url), - OAuthCredentialsStoreMode::Keyring => { - load_oauth_tokens_from_keyring(&keyring_store, keyring_backend_kind, server_name, url) - .with_context(|| "failed to read OAuth tokens from keyring".to_string()) - } - } + Ok( + load_oauth_tokens_with_source(server_name, url, store_mode, keyring_backend_kind)? + .map(|loaded| loaded.tokens), + ) } pub(crate) fn oauth_token_status( @@ -164,23 +167,6 @@ fn refresh_expires_in_from_timestamp(tokens: &mut StoredOAuthTokens) { } } -fn load_oauth_tokens_from_keyring_with_fallback_to_file( - keyring_store: &K, - keyring_backend_kind: AuthKeyringBackendKind, - server_name: &str, - url: &str, -) -> Result> { - match load_oauth_tokens_from_keyring(keyring_store, keyring_backend_kind, server_name, url) { - Ok(Some(tokens)) => Ok(Some(tokens)), - Ok(None) => load_oauth_tokens_from_file(server_name, url), - Err(error) => { - warn!("failed to read OAuth tokens from keyring: {error}"); - load_oauth_tokens_from_file(server_name, url) - .with_context(|| format!("failed to read OAuth tokens from keyring: {error}")) - } - } -} - fn load_oauth_tokens_from_keyring( keyring_store: &K, keyring_backend_kind: AuthKeyringBackendKind, @@ -249,16 +235,32 @@ pub fn save_oauth_tokens( keyring_backend_kind: AuthKeyringBackendKind, ) -> Result<()> { let keyring_store = DefaultKeyringStore; + save_oauth_tokens_with_keyring_store( + &keyring_store, + server_name, + tokens, + store_mode, + keyring_backend_kind, + ) +} + +fn save_oauth_tokens_with_keyring_store( + keyring_store: &K, + server_name: &str, + tokens: &StoredOAuthTokens, + store_mode: OAuthCredentialsStoreMode, + keyring_backend_kind: AuthKeyringBackendKind, +) -> Result<()> { match store_mode { OAuthCredentialsStoreMode::Auto => save_oauth_tokens_with_keyring_with_fallback_to_file( - &keyring_store, + keyring_store, keyring_backend_kind, server_name, tokens, ), OAuthCredentialsStoreMode::File => save_oauth_tokens_to_file(tokens), - OAuthCredentialsStoreMode::Keyring => save_oauth_tokens_with_keyring( - &keyring_store, + OAuthCredentialsStoreMode::Keyring => save_oauth_tokens_with_keyring_and_cleanup_file( + keyring_store, keyring_backend_kind, server_name, tokens, @@ -282,6 +284,24 @@ fn save_oauth_tokens_with_keyring( } } +/// Saves to the selected keyring backend, then best-effort removes the fallback file entry. +/// +/// A cleanup failure does not change the current client's selected authority, but it can leave +/// legacy residue that a different `Auto` process may discover if keyring availability changes. +fn save_oauth_tokens_with_keyring_and_cleanup_file( + keyring_store: &K, + keyring_backend_kind: AuthKeyringBackendKind, + server_name: &str, + tokens: &StoredOAuthTokens, +) -> Result<()> { + save_oauth_tokens_with_keyring(keyring_store, keyring_backend_kind, server_name, tokens)?; + let key = compute_store_key(server_name, &tokens.url)?; + if let Err(error) = delete_oauth_tokens_from_file(&key) { + warn!("failed to remove OAuth tokens from fallback storage: {error:?}"); + } + Ok(()) +} + fn save_oauth_tokens_to_direct_keyring( keyring_store: &K, server_name: &str, @@ -291,12 +311,7 @@ fn save_oauth_tokens_to_direct_keyring( let key = compute_store_key(server_name, &tokens.url)?; match keyring_store.save(KEYRING_SERVICE, &key, &serialized) { - Ok(()) => { - if let Err(error) = delete_oauth_tokens_from_file(&key) { - warn!("failed to remove OAuth tokens from fallback storage: {error:?}"); - } - Ok(()) - } + Ok(()) => Ok(()), Err(error) => { let message = format!( "failed to write OAuth tokens to keyring: {}", @@ -324,13 +339,7 @@ fn save_oauth_tokens_to_secrets_keyring( let secret_name = compute_secret_name(server_name, &tokens.url)?; manager .set(&SecretScope::Global, &secret_name, &serialized) - .context("failed to write OAuth tokens to encrypted storage")?; - - let key = compute_store_key(server_name, &tokens.url)?; - if let Err(error) = delete_oauth_tokens_from_file(&key) { - warn!("failed to remove OAuth tokens from fallback storage: {error:?}"); - } - Ok(()) + .context("failed to write OAuth tokens to encrypted storage") } fn save_oauth_tokens_with_keyring_with_fallback_to_file( @@ -339,7 +348,12 @@ fn save_oauth_tokens_with_keyring_with_fallback_to_file Result<()> { - match save_oauth_tokens_with_keyring(keyring_store, keyring_backend_kind, server_name, tokens) { + match save_oauth_tokens_with_keyring_and_cleanup_file( + keyring_store, + keyring_backend_kind, + server_name, + tokens, + ) { Ok(()) => Ok(()), Err(error) => { let message = error.to_string(); @@ -444,134 +458,6 @@ fn delete_oauth_tokens_from_secrets_keyring( Ok(secrets_removed) } -#[derive(Clone)] -pub(crate) struct OAuthPersistor { - inner: Arc, -} - -struct OAuthPersistorInner { - server_name: String, - url: String, - authorization_manager: Arc>, - store_mode: OAuthCredentialsStoreMode, - keyring_backend_kind: AuthKeyringBackendKind, - last_credentials: Mutex>, -} - -impl OAuthPersistor { - pub(crate) fn new( - server_name: String, - url: String, - authorization_manager: Arc>, - store_mode: OAuthCredentialsStoreMode, - keyring_backend_kind: AuthKeyringBackendKind, - initial_credentials: Option, - ) -> Self { - Self { - inner: Arc::new(OAuthPersistorInner { - server_name, - url, - authorization_manager, - store_mode, - keyring_backend_kind, - last_credentials: Mutex::new(initial_credentials), - }), - } - } - - /// Persists the latest stored credentials if they have changed. - /// Deletes the credentials if they are no longer present. - #[expect( - clippy::await_holding_invalid_type, - reason = "AuthorizationManager async access must be serialized through its mutex" - )] - pub(crate) async fn persist_if_needed(&self) -> Result<()> { - let (client_id, maybe_credentials) = { - let manager = self.inner.authorization_manager.clone(); - let guard = manager.lock().await; - guard.get_credentials().await - }?; - - match maybe_credentials { - Some(credentials) => { - let mut last_credentials = self.inner.last_credentials.lock().await; - let new_token_response = WrappedOAuthTokenResponse(credentials.clone()); - let same_token = last_credentials - .as_ref() - .map(|prev| prev.token_response == new_token_response) - .unwrap_or(false); - let expires_at = if same_token { - last_credentials.as_ref().and_then(|prev| prev.expires_at) - } else { - compute_expires_at_millis(&credentials) - }; - let stored = StoredOAuthTokens { - server_name: self.inner.server_name.clone(), - url: self.inner.url.clone(), - client_id, - token_response: new_token_response, - expires_at, - }; - if last_credentials.as_ref() != Some(&stored) { - save_oauth_tokens( - &self.inner.server_name, - &stored, - self.inner.store_mode, - self.inner.keyring_backend_kind, - )?; - *last_credentials = Some(stored); - } - } - None => { - let mut last_serialized = self.inner.last_credentials.lock().await; - if last_serialized.take().is_some() - && let Err(error) = delete_oauth_tokens( - &self.inner.server_name, - &self.inner.url, - self.inner.store_mode, - self.inner.keyring_backend_kind, - ) - { - warn!( - "failed to remove OAuth tokens for server {}: {error}", - self.inner.server_name - ); - } - } - } - - Ok(()) - } - - #[expect( - clippy::await_holding_invalid_type, - reason = "AuthorizationManager async access must be serialized through its mutex" - )] - pub(crate) async fn refresh_if_needed(&self) -> Result<()> { - let expires_at = { - let guard = self.inner.last_credentials.lock().await; - guard.as_ref().and_then(|tokens| tokens.expires_at) - }; - - if !token_needs_refresh(expires_at) { - return Ok(()); - } - - { - let manager = self.inner.authorization_manager.clone(); - let guard = manager.lock().await; - guard.refresh_token().await.with_context(|| { - format!( - "failed to refresh OAuth tokens for server {}", - self.inner.server_name - ) - })?; - } - - self.persist_if_needed().await - } -} - const FALLBACK_FILENAME: &str = ".credentials.json"; const MCP_SERVER_TYPE: &str = "http"; @@ -812,17 +698,30 @@ fn sha_256_prefix(value: &Value) -> Result { #[cfg(test)] mod tests { + use super::refresh_lock::RefreshCredentialLock; use super::*; use anyhow::Result; use codex_secrets::compute_keyring_account; use keyring::Error as KeyringError; use pretty_assertions::assert_eq; + use rmcp::transport::auth::OAuthState; + use serde_json::json; use std::sync::Arc; use std::sync::Mutex; use std::sync::MutexGuard; use std::sync::OnceLock; use std::sync::PoisonError; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering; + use std::sync::mpsc; use tempfile::tempdir; + use tokio::sync::Mutex as TokioMutex; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::body_string_contains; + use wiremock::matchers::method; + use wiremock::matchers::path; use codex_keyring_store::tests::MockKeyringStore; @@ -898,7 +797,8 @@ mod tests { &tokens.url, )? .expect("tokens should load from fallback"); - assert_tokens_match_without_expiry(&loaded, &expected); + assert_eq!(loaded.store, ResolvedOAuthCredentialStore::File); + assert_tokens_match_without_expiry(&loaded.tokens, &expected); Ok(()) } @@ -920,7 +820,43 @@ mod tests { &tokens.url, )? .expect("tokens should load from fallback"); - assert_tokens_match_without_expiry(&loaded, &expected); + assert_eq!(loaded.store, ResolvedOAuthCredentialStore::File); + assert_tokens_match_without_expiry(&loaded.tokens, &expected); + Ok(()) + } + + #[test] + fn auto_resolution_prioritizes_keyring_and_tracks_its_source() -> Result<()> { + let _env = TempCodexHome::new(); + let store = MockKeyringStore::default(); + let keyring_tokens = sample_tokens(); + let mut file_tokens = sample_tokens(); + file_tokens + .token_response + .0 + .set_access_token(AccessToken::new("file-access-token".to_string())); + super::save_oauth_tokens_to_file(&file_tokens)?; + super::save_oauth_tokens_with_keyring( + &store, + AuthKeyringBackendKind::Direct, + &keyring_tokens.server_name, + &keyring_tokens, + )?; + + let loaded = super::load_oauth_tokens_with_source_and_keyring_store( + &store, + &keyring_tokens.server_name, + &keyring_tokens.url, + OAuthCredentialsStoreMode::Auto, + AuthKeyringBackendKind::Direct, + )? + .expect("Auto should load keyring credentials"); + + assert_eq!( + loaded.store, + ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Direct) + ); + assert_tokens_match_without_expiry(&loaded.tokens, &keyring_tokens); Ok(()) } @@ -1058,6 +994,422 @@ mod tests { Ok(()) } + #[tokio::test] + async fn refresh_transaction_preserves_credentials_when_resolved_keyring_reread_fails() + -> Result<()> { + let _env = TempCodexHome::new(); + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/.well-known/oauth-authorization-server/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "authorization_endpoint": format!("{}/oauth/authorize", server.uri()), + "token_endpoint": format!("{}/oauth/token", server.uri()), + "scopes_supported": ["scope-a", "scope-b"], + }))) + .mount(&server) + .await; + let store = MockKeyringStore::default(); + let initial_tokens = expired_sample_tokens(&format!("{}/mcp", server.uri())); + let key = super::compute_store_key(&initial_tokens.server_name, &initial_tokens.url)?; + store.set_error(&key, KeyringError::Invalid("error".into(), "load".into())); + + let manager = authorization_manager_for(&initial_tokens).await?; + let persistor = OAuthPersistor::new( + initial_tokens.server_name.clone(), + initial_tokens.url.clone(), + manager.clone(), + ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Direct), + Some(initial_tokens.clone()), + ); + + let error = persistor + .refresh_if_needed_with_keyring_store(&store) + .await + .expect_err("keyring reread failure should abort refresh"); + + assert!( + error + .to_string() + .contains("failed to reread OAuth tokens from resolved keyring storage"), + "unexpected error: {error:#}" + ); + let manager_tokens = tokens_from_manager(&manager).await?; + assert_eq!(manager_tokens.token_response, initial_tokens.token_response); + Ok(()) + } + + #[tokio::test] + async fn resolved_keyring_write_failure_never_falls_back_to_file() -> Result<()> { + let _env = TempCodexHome::new(); + let server = MockServer::start().await; + mount_oauth_metadata(&server).await; + let store = MockKeyringStore::default(); + let mut initial_tokens = sample_tokens(); + initial_tokens.url = format!("{}/mcp", server.uri()); + let mut updated_tokens = initial_tokens.clone(); + updated_tokens + .token_response + .0 + .set_access_token(AccessToken::new("updated-access-token".to_string())); + + let manager = authorization_manager_for(&updated_tokens).await?; + let persistor = OAuthPersistor::new( + initial_tokens.server_name.clone(), + initial_tokens.url.clone(), + manager, + ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Direct), + Some(initial_tokens), + ); + let key = super::compute_store_key(&updated_tokens.server_name, &updated_tokens.url)?; + store.set_error(&key, KeyringError::Invalid("error".into(), "save".into())); + + let error = persistor + .persist_if_needed_with_keyring_store(&store) + .await + .expect_err("resolved keyring write should fail instead of falling back"); + + assert!( + error + .to_string() + .contains("failed to write OAuth tokens to keyring"), + "unexpected error: {error:#}" + ); + assert!(!super::fallback_file_path()?.exists()); + Ok(()) + } + + #[tokio::test] + async fn refresh_transaction_adopts_valid_reread_without_provider_refresh() -> Result<()> { + let _env = TempCodexHome::new(); + let server = MockServer::start().await; + mount_oauth_metadata(&server).await; + let store = MockKeyringStore::default(); + let mut initial_tokens = expired_sample_tokens(&format!("{}/mcp", server.uri())); + initial_tokens + .token_response + .0 + .set_refresh_token(Some(RefreshToken::new("stale-refresh-token".to_string()))); + + let mut latest_tokens = sample_tokens(); + latest_tokens.url.clone_from(&initial_tokens.url); + latest_tokens + .token_response + .0 + .set_access_token(AccessToken::new( + "already-refreshed-access-token".to_string(), + )); + latest_tokens + .token_response + .0 + .set_refresh_token(Some(RefreshToken::new( + "already-rotated-refresh-token".to_string(), + ))); + + super::save_oauth_tokens_with_keyring_store( + &store, + &latest_tokens.server_name, + &latest_tokens, + OAuthCredentialsStoreMode::Keyring, + AuthKeyringBackendKind::Direct, + )?; + + let manager = authorization_manager_for(&initial_tokens).await?; + let persistor = OAuthPersistor::new( + initial_tokens.server_name.clone(), + initial_tokens.url.clone(), + manager.clone(), + ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Direct), + Some(initial_tokens), + ); + + persistor + .refresh_if_needed_with_keyring_store(&store) + .await?; + + let manager_tokens = tokens_from_manager(&manager).await?; + assert_eq!( + access_token(&manager_tokens), + "already-refreshed-access-token" + ); + assert_eq!( + refresh_token(&manager_tokens), + Some("already-rotated-refresh-token".to_string()) + ); + Ok(()) + } + + #[tokio::test] + async fn refresh_transaction_refreshes_when_only_derived_expires_in_drifted() -> Result<()> { + let _env = TempCodexHome::new(); + let server = MockServer::start().await; + mount_oauth_metadata(&server).await; + mount_refresh_response( + &server, + "refresh-token", + "refreshed-after-expiry-drift", + "rotated-after-expiry-drift", + ) + .await; + + let store = MockKeyringStore::default(); + let mut initial_tokens = sample_tokens(); + initial_tokens.url = format!("{}/mcp", server.uri()); + initial_tokens.expires_at = Some(now_millis().saturating_add(5_000)); + initial_tokens + .token_response + .0 + .set_expires_in(Some(&Duration::from_secs(3600))); + super::save_oauth_tokens_with_keyring_store( + &store, + &initial_tokens.server_name, + &initial_tokens, + OAuthCredentialsStoreMode::Keyring, + AuthKeyringBackendKind::Direct, + )?; + + let manager = authorization_manager_for(&initial_tokens).await?; + let persistor = OAuthPersistor::new( + initial_tokens.server_name.clone(), + initial_tokens.url.clone(), + manager, + ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Direct), + Some(initial_tokens.clone()), + ); + + persistor + .refresh_if_needed_with_keyring_store(&store) + .await?; + + server.verify().await; + let stored = super::load_oauth_tokens_with_source_and_keyring_store( + &store, + &initial_tokens.server_name, + &initial_tokens.url, + OAuthCredentialsStoreMode::Keyring, + AuthKeyringBackendKind::Direct, + )? + .expect("refreshed tokens should be persisted"); + assert_eq!(access_token(&stored.tokens), "refreshed-after-expiry-drift"); + assert_eq!( + refresh_token(&stored.tokens), + Some("rotated-after-expiry-drift".to_string()) + ); + Ok(()) + } + + #[tokio::test] + async fn refresh_transaction_uses_latest_refresh_token_when_reread_is_expired() -> Result<()> { + let _env = TempCodexHome::new(); + let server = MockServer::start().await; + mount_oauth_metadata(&server).await; + mount_refresh_response( + &server, + "latest-refresh-token", + "refreshed-from-latest-token", + "rotated-from-latest-token", + ) + .await; + + let store = MockKeyringStore::default(); + let mut initial_tokens = expired_sample_tokens(&format!("{}/mcp", server.uri())); + initial_tokens + .token_response + .0 + .set_refresh_token(Some(RefreshToken::new("stale-refresh-token".to_string()))); + + let mut latest_tokens = initial_tokens.clone(); + latest_tokens + .token_response + .0 + .set_access_token(AccessToken::new("latest-expired-access-token".to_string())); + latest_tokens + .token_response + .0 + .set_refresh_token(Some(RefreshToken::new("latest-refresh-token".to_string()))); + super::save_oauth_tokens_with_keyring_store( + &store, + &latest_tokens.server_name, + &latest_tokens, + OAuthCredentialsStoreMode::Keyring, + AuthKeyringBackendKind::Direct, + )?; + + let manager = authorization_manager_for(&initial_tokens).await?; + let persistor = OAuthPersistor::new( + initial_tokens.server_name.clone(), + initial_tokens.url.clone(), + manager, + ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Direct), + Some(initial_tokens.clone()), + ); + + persistor + .refresh_if_needed_with_keyring_store(&store) + .await?; + + server.verify().await; + let stored = super::load_oauth_tokens_with_source_and_keyring_store( + &store, + &initial_tokens.server_name, + &initial_tokens.url, + OAuthCredentialsStoreMode::Keyring, + AuthKeyringBackendKind::Direct, + )? + .expect("refreshed tokens should be persisted"); + assert_eq!(access_token(&stored.tokens), "refreshed-from-latest-token"); + assert_eq!( + refresh_token(&stored.tokens), + Some("rotated-from-latest-token".to_string()) + ); + Ok(()) + } + + #[tokio::test] + async fn provider_refresh_timeout_permits_a_later_serialized_retry() -> Result<()> { + let _env = TempCodexHome::new(); + let server = MockServer::start().await; + mount_oauth_metadata(&server).await; + let request_count = Arc::new(AtomicUsize::new(/*v*/ 0)); + let request_count_for_response = Arc::clone(&request_count); + Mock::given(method("POST")) + .and(path("/oauth/token")) + .and(body_string_contains("grant_type=refresh_token")) + .and(body_string_contains("refresh_token=refresh-token")) + .respond_with(move |_request: &wiremock::Request| { + let response = ResponseTemplate::new(200).set_body_json(json!({ + "access_token": "retried-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "retried-refresh-token", + "scope": "scope-a scope-b", + })); + if request_count_for_response.fetch_add(1, Ordering::SeqCst) == 0 { + response.set_delay(Duration::from_millis(500)) + } else { + response + } + }) + .expect(2) + .mount(&server) + .await; + + let store = MockKeyringStore::default(); + let initial_tokens = expired_sample_tokens(&format!("{}/mcp", server.uri())); + super::save_oauth_tokens_with_keyring_store( + &store, + &initial_tokens.server_name, + &initial_tokens, + OAuthCredentialsStoreMode::Keyring, + AuthKeyringBackendKind::Direct, + )?; + let manager = authorization_manager_for(&initial_tokens).await?; + let persistor = OAuthPersistor::new( + initial_tokens.server_name.clone(), + initial_tokens.url.clone(), + manager, + ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Direct), + Some(initial_tokens.clone()), + ); + + let first_error = persistor + .refresh_if_needed_with_keyring_store_and_timeout(&store, Duration::from_millis(100)) + .await + .expect_err("the first provider request should reach its explicit timeout"); + assert!(first_error.to_string().contains("timed out after 100ms")); + + persistor + .refresh_if_needed_with_keyring_store_and_timeout(&store, Duration::from_secs(1)) + .await?; + + let stored = super::load_oauth_tokens_with_source_and_keyring_store( + &store, + &initial_tokens.server_name, + &initial_tokens.url, + OAuthCredentialsStoreMode::Keyring, + AuthKeyringBackendKind::Direct, + )? + .expect("the later retry should persist the rotated credentials"); + assert_eq!(access_token(&stored.tokens), "retried-access-token"); + assert_eq!( + refresh_token(&stored.tokens), + Some("retried-refresh-token".to_string()) + ); + assert_eq!(request_count.load(Ordering::SeqCst), 2); + server.verify().await; + Ok(()) + } + + #[tokio::test] + async fn caller_cancellation_does_not_cancel_refresh_and_persistence() -> Result<()> { + let _env = TempCodexHome::new(); + let server = MockServer::start().await; + mount_oauth_metadata(&server).await; + let refresh_started = mount_refresh_response_with_signal( + &server, + "refresh-token", + "cancel-safe-access-token", + "cancel-safe-refresh-token", + Duration::from_millis(300), + ) + .await; + + let store = MockKeyringStore::default(); + let initial_tokens = expired_sample_tokens(&format!("{}/mcp", server.uri())); + super::save_oauth_tokens_with_keyring_store( + &store, + &initial_tokens.server_name, + &initial_tokens, + OAuthCredentialsStoreMode::Keyring, + AuthKeyringBackendKind::Direct, + )?; + let manager = authorization_manager_for(&initial_tokens).await?; + let persistor = OAuthPersistor::new( + initial_tokens.server_name.clone(), + initial_tokens.url.clone(), + manager, + ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Direct), + Some(initial_tokens.clone()), + ); + let caller = tokio::spawn({ + let persistor = persistor.clone(); + let store = store.clone(); + async move { persistor.refresh_if_needed_with_keyring_store(&store).await } + }); + + wait_for_signal(refresh_started).await?; + caller.abort(); + let caller_error = caller + .await + .expect_err("the caller task should observe cancellation"); + assert!(caller_error.is_cancelled()); + + // The provider handler fires only after the owned transaction has acquired this lock. + // Reacquiring it therefore waits deterministically for refresh and persistence to finish, + // without relying on a scheduler-sensitive sleep after aborting the caller. + let store_key = super::compute_store_key(&initial_tokens.server_name, &initial_tokens.url)?; + let _lock = tokio::time::timeout( + Duration::from_secs(/*secs*/ 2), + RefreshCredentialLock::acquire(&store_key), + ) + .await + .expect("the independently owned refresh should release its credential lock")?; + let stored = super::load_oauth_tokens_with_source_and_keyring_store( + &store, + &initial_tokens.server_name, + &initial_tokens.url, + OAuthCredentialsStoreMode::Keyring, + AuthKeyringBackendKind::Direct, + )? + .expect("the independently owned refresh should still persist credentials"); + assert_eq!(access_token(&stored.tokens), "cancel-safe-access-token"); + assert_eq!( + refresh_token(&stored.tokens), + Some("cancel-safe-refresh-token".to_string()) + ); + server.verify().await; + Ok(()) + } + #[test] fn save_oauth_tokens_with_secrets_backend_falls_back_to_file_when_keyring_fails() -> Result<()> { @@ -1344,6 +1696,152 @@ mod tests { ); } + async fn authorization_manager_for( + tokens: &StoredOAuthTokens, + ) -> Result>> { + let mut state = OAuthState::new(tokens.url.clone(), Some(reqwest::Client::new())).await?; + state + .set_credentials(&tokens.client_id, tokens.token_response.0.clone()) + .await?; + let manager = match state { + OAuthState::Authorized(manager) | OAuthState::Unauthorized(manager) => manager, + OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => { + anyhow::bail!("unexpected OAuth state") + } + _ => anyhow::bail!("unexpected OAuth state"), + }; + Ok(Arc::new(TokioMutex::new(manager))) + } + + async fn mount_oauth_metadata(server: &MockServer) { + Mock::given(method("GET")) + .and(path("/.well-known/oauth-authorization-server/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "authorization_endpoint": format!("{}/oauth/authorize", server.uri()), + "token_endpoint": format!("{}/oauth/token", server.uri()), + "scopes_supported": ["scope-a", "scope-b"], + }))) + .mount(server) + .await; + } + + async fn mount_refresh_response( + server: &MockServer, + request_refresh_token: &str, + response_access_token: &str, + response_refresh_token: &str, + ) { + Mock::given(method("POST")) + .and(path("/oauth/token")) + .and(body_string_contains("grant_type=refresh_token")) + .and(body_string_contains(format!( + "refresh_token={request_refresh_token}" + ))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "access_token": response_access_token, + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": response_refresh_token, + "scope": "scope-a scope-b", + }))) + .expect(1) + .mount(server) + .await; + } + + async fn mount_refresh_response_with_signal( + server: &MockServer, + request_refresh_token: &str, + response_access_token: &str, + response_refresh_token: &str, + response_delay: Duration, + ) -> mpsc::Receiver<()> { + let (tx, rx) = mpsc::channel(); + let response_access_token = response_access_token.to_string(); + let response_refresh_token = response_refresh_token.to_string(); + Mock::given(method("POST")) + .and(path("/oauth/token")) + .and(body_string_contains("grant_type=refresh_token")) + .and(body_string_contains(format!( + "refresh_token={request_refresh_token}" + ))) + .respond_with(move |_request: &wiremock::Request| { + let _ = tx.send(()); + let access_token = response_access_token.clone(); + let refresh_token = response_refresh_token.clone(); + ResponseTemplate::new(200) + .set_delay(response_delay) + .set_body_json(json!({ + "access_token": access_token, + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": refresh_token, + "scope": "scope-a scope-b", + })) + }) + .expect(1) + .mount(server) + .await; + rx + } + + async fn wait_for_signal(rx: mpsc::Receiver<()>) -> Result<()> { + tokio::task::spawn_blocking(move || { + rx.recv_timeout(Duration::from_secs(5)) + .context("timed out waiting for refresh request") + }) + .await? + } + + #[expect( + clippy::await_holding_invalid_type, + reason = "AuthorizationManager async access must be serialized through its mutex" + )] + async fn tokens_from_manager( + manager: &Arc>, + ) -> Result { + let guard = manager.lock().await; + let (client_id, token_response) = guard.get_credentials().await?; + let token_response = token_response.expect("manager should have token response"); + Ok(StoredOAuthTokens { + server_name: "test-server".to_string(), + url: "https://example.test".to_string(), + client_id, + token_response: WrappedOAuthTokenResponse(token_response), + expires_at: None, + }) + } + + fn access_token(tokens: &StoredOAuthTokens) -> &str { + tokens.token_response.0.access_token().secret() + } + + fn refresh_token(tokens: &StoredOAuthTokens) -> Option { + tokens + .token_response + .0 + .refresh_token() + .map(|token| token.secret().to_string()) + } + + fn expired_sample_tokens(url: &str) -> StoredOAuthTokens { + let mut tokens = sample_tokens(); + tokens.url = url.to_string(); + tokens.expires_at = Some(0); + tokens + .token_response + .0 + .set_expires_in(Some(&Duration::ZERO)); + tokens + } + + fn now_millis() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)) + .as_millis() as u64 + } + fn sample_tokens() -> StoredOAuthTokens { let mut response = OAuthTokenResponse::new( AccessToken::new("access-token".to_string()), diff --git a/codex-rs/rmcp-client/src/oauth/persistor.rs b/codex-rs/rmcp-client/src/oauth/persistor.rs new file mode 100644 index 000000000000..9a740885bbe1 --- /dev/null +++ b/codex-rs/rmcp-client/src/oauth/persistor.rs @@ -0,0 +1,406 @@ +//! Lifecycle-local persistence and serialized refresh transactions for MCP OAuth credentials. + +use std::sync::Arc; +use std::time::Duration; +use std::time::Instant; +use std::time::SystemTime; +use std::time::UNIX_EPOCH; + +use anyhow::Context; +use anyhow::Result; +use codex_config::types::AuthKeyringBackendKind; +use codex_keyring_store::DefaultKeyringStore; +use codex_keyring_store::KeyringStore; +use oauth2::TokenResponse; +use rmcp::transport::auth::AuthorizationManager; +use rmcp::transport::auth::CredentialStore as _; +use rmcp::transport::auth::InMemoryCredentialStore; +use rmcp::transport::auth::StoredCredentials; +use tokio::sync::Mutex; +use tokio::time::timeout; +use tracing::debug; +use tracing::warn; + +use super::ResolvedOAuthCredentialStore; +use super::StoredOAuthTokens; +use super::WrappedOAuthTokenResponse; +use super::compute_expires_at_millis; +use super::compute_store_key; +use super::delete_oauth_tokens_from_direct_keyring; +use super::delete_oauth_tokens_from_file; +use super::delete_oauth_tokens_from_secrets_keyring; +use super::load_oauth_tokens_from_file; +use super::load_oauth_tokens_from_keyring; +use super::refresh_lock::RefreshCredentialLock; +use super::save_oauth_tokens_to_file; +use super::save_oauth_tokens_with_keyring; +use super::token_needs_refresh; + +const REFRESH_REQUEST_TIMEOUT: Duration = Duration::from_secs(45); + +#[derive(Clone)] +pub(crate) struct OAuthPersistor { + inner: Arc, +} + +struct OAuthPersistorInner { + server_name: String, + url: String, + authorization_manager: Arc>, + credential_store: ResolvedOAuthCredentialStore, + last_credentials: Mutex>, +} + +impl OAuthPersistor { + pub(crate) fn new( + server_name: String, + url: String, + authorization_manager: Arc>, + credential_store: ResolvedOAuthCredentialStore, + initial_credentials: Option, + ) -> Self { + Self { + inner: Arc::new(OAuthPersistorInner { + server_name, + url, + authorization_manager, + credential_store, + last_credentials: Mutex::new(initial_credentials), + }), + } + } + + /// Persists the latest stored credentials if they have changed. + /// Deletes the credentials if they are no longer present. + pub(crate) async fn persist_if_needed(&self) -> Result<()> { + self.persist_if_needed_with_keyring_store(&DefaultKeyringStore) + .await + } + + #[expect( + clippy::await_holding_invalid_type, + reason = "AuthorizationManager async access must be serialized through its mutex" + )] + pub(super) async fn persist_if_needed_with_keyring_store( + &self, + keyring_store: &K, + ) -> Result<()> { + let (client_id, maybe_credentials) = { + let manager = self.inner.authorization_manager.clone(); + let guard = manager.lock().await; + guard.get_credentials().await + }?; + + match maybe_credentials { + Some(credentials) => { + let mut last_credentials = self.inner.last_credentials.lock().await; + let new_token_response = WrappedOAuthTokenResponse(credentials.clone()); + let same_token = last_credentials + .as_ref() + .map(|prev| prev.token_response == new_token_response) + .unwrap_or(false); + let expires_at = if same_token { + last_credentials.as_ref().and_then(|prev| prev.expires_at) + } else { + compute_expires_at_millis(&credentials) + }; + let stored = StoredOAuthTokens { + server_name: self.inner.server_name.clone(), + url: self.inner.url.clone(), + client_id, + token_response: new_token_response, + expires_at, + }; + if last_credentials.as_ref() != Some(&stored) { + match self.inner.credential_store { + ResolvedOAuthCredentialStore::File => save_oauth_tokens_to_file(&stored)?, + ResolvedOAuthCredentialStore::Keyring(keyring_backend_kind) => { + save_oauth_tokens_with_keyring( + keyring_store, + keyring_backend_kind, + &self.inner.server_name, + &stored, + )?; + } + } + *last_credentials = Some(stored); + } + } + None => { + let mut last_serialized = self.inner.last_credentials.lock().await; + if last_serialized.take().is_some() + && let Err(error) = match self.inner.credential_store { + ResolvedOAuthCredentialStore::File => { + let key = compute_store_key(&self.inner.server_name, &self.inner.url)?; + delete_oauth_tokens_from_file(&key).map(|_| ()) + } + ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Direct) => { + delete_oauth_tokens_from_direct_keyring( + keyring_store, + &self.inner.server_name, + &self.inner.url, + ) + .map(|_| ()) + } + ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Secrets) => { + delete_oauth_tokens_from_secrets_keyring( + keyring_store, + &self.inner.server_name, + &self.inner.url, + ) + .map(|_| ()) + } + } + { + warn!( + "failed to remove OAuth tokens for server {}: {error}", + self.inner.server_name + ); + } + } + } + + Ok(()) + } + + pub(crate) async fn refresh_if_needed(&self) -> Result<()> { + self.refresh_if_needed_with_keyring_store(&DefaultKeyringStore) + .await + } + + pub(super) async fn refresh_if_needed_with_keyring_store( + &self, + keyring_store: &K, + ) -> Result<()> { + self.refresh_if_needed_with_keyring_store_and_timeout( + keyring_store, + REFRESH_REQUEST_TIMEOUT, + ) + .await + } + + pub(super) async fn refresh_if_needed_with_keyring_store_and_timeout< + K: KeyringStore + Clone + 'static, + >( + &self, + keyring_store: &K, + refresh_request_timeout: Duration, + ) -> Result<()> { + let expires_at = { + let guard = self.inner.last_credentials.lock().await; + guard.as_ref().and_then(|tokens| tokens.expires_at) + }; + + if !token_needs_refresh(expires_at) { + return Ok(()); + } + + self.run_owned_refresh_transaction(keyring_store.clone(), refresh_request_timeout) + .await + } + + async fn run_owned_refresh_transaction( + &self, + keyring_store: K, + refresh_request_timeout: Duration, + ) -> Result<()> { + let persistor = self.clone(); + let server_name = self.inner.server_name.clone(); + // Once the provider may consume a rotating refresh token, dropping the caller's future + // must not also drop refresh plus persistence. Dropping this JoinHandle detaches the task, + // which continues under the credential lock until its explicit lock/provider bounds. + // + // A provider timeout deliberately leaves the outcome unknown, releases the lock, and + // permits a later serialized retry. Some providers accept the previous token during a + // grace period; otherwise that retry surfaces reauthorization. We accept that residual + // token-family-revocation risk rather than holding the lock indefinitely. + tokio::spawn(async move { + persistor + .refresh_transaction(&keyring_store, refresh_request_timeout) + .await + }) + .await + .with_context(|| format!("OAuth refresh task failed for server {server_name}"))? + } + + #[expect( + clippy::await_holding_invalid_type, + reason = "AuthorizationManager async access must be serialized through its mutex" + )] + #[tracing::instrument( + level = "debug", + skip_all, + fields(server_name = %self.inner.server_name), + err + )] + async fn refresh_transaction( + &self, + keyring_store: &K, + refresh_request_timeout: Duration, + ) -> Result<()> { + let transaction_started_at = Instant::now(); + let lock_started_at = Instant::now(); + debug!("waiting for the MCP OAuth credential transaction lock"); + let key = compute_store_key(&self.inner.server_name, &self.inner.url)?; + let _lock = RefreshCredentialLock::acquire(&key).await?; + debug!( + lock_wait_ms = lock_started_at.elapsed().as_millis(), + "acquired the MCP OAuth credential transaction lock" + ); + // The refresh transaction must stay on the store that supplied its snapshot. Falling back + // here could replay an older rotating refresh token from the other store. We assume store + // availability is stable for this client lifecycle and surface violations of that + // assumption instead of switching stores. + let latest = match self.inner.credential_store { + ResolvedOAuthCredentialStore::File => { + load_oauth_tokens_from_file(&self.inner.server_name, &self.inner.url) + .context("failed to reread OAuth tokens from resolved file storage")? + } + ResolvedOAuthCredentialStore::Keyring(keyring_backend_kind) => { + load_oauth_tokens_from_keyring( + keyring_store, + keyring_backend_kind, + &self.inner.server_name, + &self.inner.url, + ) + .context( + "failed to reread OAuth tokens from resolved keyring storage; refusing file fallback", + )? + } + }; + + // The pre-lock snapshot only decides whether a refresh transaction might be needed. Once + // the lock is held, this reread is authoritative: adopt it before deciding whether to + // refresh so this process never sends a refresh token superseded by another process. + let Some(latest) = latest else { + self.clear_manager_credentials().await; + let mut last_credentials = self.inner.last_credentials.lock().await; + *last_credentials = None; + anyhow::bail!( + "OAuth tokens for server {} were removed before refresh; authorization required", + self.inner.server_name + ); + }; + + if !token_needs_refresh(latest.expires_at) { + self.adopt_credentials(latest).await?; + return Ok(()); + } + + self.adopt_credentials(latest).await?; + + { + let manager = self.inner.authorization_manager.clone(); + let guard = manager.lock().await; + let provider_started_at = Instant::now(); + debug!( + timeout_ms = refresh_request_timeout.as_millis(), + "requesting refreshed MCP OAuth credentials from the provider" + ); + match timeout(refresh_request_timeout, guard.refresh_token()).await { + Ok(Ok(_token_response)) => { + debug!( + provider_elapsed_ms = provider_started_at.elapsed().as_millis(), + "received refreshed MCP OAuth credentials from the provider" + ); + } + Ok(Err(error)) => { + warn!( + provider_elapsed_ms = provider_started_at.elapsed().as_millis(), + error = %error, + "MCP OAuth provider refresh failed" + ); + return Err(error).with_context(|| { + format!( + "failed to refresh OAuth tokens for server {}", + self.inner.server_name + ) + }); + } + Err(_) => { + warn!( + provider_elapsed_ms = provider_started_at.elapsed().as_millis(), + timeout_ms = refresh_request_timeout.as_millis(), + "MCP OAuth provider refresh timed out; the outcome is unknown and a later serialized retry is permitted" + ); + anyhow::bail!( + "timed out after {refresh_request_timeout:?} refreshing OAuth tokens for server {}", + self.inner.server_name + ); + } + } + } + + // Once the provider returns a rotated token, persistence must finish before the credential + // lock is released. In particular, caller startup deadlines must not cancel this step. + let result = self + .persist_if_needed_with_keyring_store(keyring_store) + .await; + if result.is_ok() { + debug!( + transaction_elapsed_ms = transaction_started_at.elapsed().as_millis(), + "completed the MCP OAuth refresh transaction" + ); + } + result + } + + async fn adopt_credentials(&self, tokens: StoredOAuthTokens) -> Result<()> { + install_tokens_in_manager(&self.inner.authorization_manager, &tokens).await?; + let mut last_credentials = self.inner.last_credentials.lock().await; + *last_credentials = Some(tokens); + Ok(()) + } + + async fn clear_manager_credentials(&self) { + let manager = self.inner.authorization_manager.clone(); + let mut guard = manager.lock().await; + guard.set_credential_store(InMemoryCredentialStore::new()); + } +} + +#[expect( + clippy::await_holding_invalid_type, + reason = "AuthorizationManager async access must be serialized through its mutex" +)] +async fn install_tokens_in_manager( + authorization_manager: &Arc>, + tokens: &StoredOAuthTokens, +) -> Result<()> { + let store = InMemoryCredentialStore::new(); + store + .save(stored_credentials_from_tokens(tokens)) + .await + .context("failed to stage OAuth tokens for authorization manager")?; + + let manager = authorization_manager.clone(); + let mut guard = manager.lock().await; + guard.set_credential_store(store); + // TODO(stevenlee): RMCP's `initialize_from_store` updates the credential store and client ID + // but not its private `current_scopes`. Credential adoption can therefore leave scope-upgrade + // state stale until RMCP exposes an adoption API that synchronizes both. + guard + .initialize_from_store() + .await + .context("failed to adopt refreshed OAuth tokens")?; + Ok(()) +} + +fn stored_credentials_from_tokens(tokens: &StoredOAuthTokens) -> StoredCredentials { + let token_response = tokens.token_response.0.clone(); + let granted_scopes = token_response + .scopes() + .map(|scopes| scopes.iter().map(|scope| scope.to_string()).collect()) + .unwrap_or_default(); + let token_received_at = SystemTime::now() + .duration_since(UNIX_EPOCH) + .ok() + .map(|duration| duration.as_secs()); + + StoredCredentials::new( + tokens.client_id.clone(), + Some(token_response), + granted_scopes, + token_received_at, + ) +} diff --git a/codex-rs/rmcp-client/src/oauth/refresh_lock.rs b/codex-rs/rmcp-client/src/oauth/refresh_lock.rs new file mode 100644 index 000000000000..b540066cd110 --- /dev/null +++ b/codex-rs/rmcp-client/src/oauth/refresh_lock.rs @@ -0,0 +1,113 @@ +//! Cross-process serialization for one MCP OAuth credential's refresh transaction. +//! +//! The guard is intentionally acquired before the authoritative credential reread and retained +//! through provider refresh and persistence. This prevents two processes from replaying the same +//! rotating refresh token or observing a partially persisted transaction. + +use anyhow::Context; +use anyhow::Result; +use codex_utils_home_dir::find_codex_home; +use sha2::Digest; +use sha2::Sha256; +use std::fs; +use std::fs::File; +use std::fs::OpenOptions; +use std::path::Path; +use std::path::PathBuf; +use std::time::Duration; +use tokio::time::sleep; +use tokio::time::timeout; + +const REFRESH_LOCK_DIR: &str = "mcp-oauth-refresh-locks"; +const REFRESH_LOCK_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(/*secs*/ 60); +const REFRESH_LOCK_RETRY_SLEEP: Duration = Duration::from_millis(/*millis*/ 50); +// Keep this internal target stable so diagnostics and cross-process tests can distinguish actual +// WouldBlock contention from a contender that merely started late and observed persisted tokens. +const LOCK_CONTENTION_EVENT_TARGET: &str = "codex_rmcp_client::oauth::refresh_lock::contention"; + +pub(super) struct RefreshCredentialLock { + _file: File, +} + +impl RefreshCredentialLock { + pub(super) async fn acquire(store_key: &str) -> Result { + let codex_home = find_codex_home()?; + Self::acquire_in(&codex_home, store_key, REFRESH_LOCK_ACQUIRE_TIMEOUT).await + } + + async fn acquire_in( + codex_home: &Path, + store_key: &str, + acquire_timeout: Duration, + ) -> Result { + let path = refresh_lock_path(codex_home, store_key); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(&path) + .with_context(|| format!("failed to open OAuth refresh lock {}", path.display()))?; + + // Bound every contender, but keep the acquired lock for the full provider request and + // persistence transaction. Releasing it while awaiting the provider would allow concurrent + // use of a rotating refresh token. + let mut reported_contention = false; + match timeout(acquire_timeout, async { + loop { + match file.try_lock() { + Ok(()) => return Ok(()), + Err(std::fs::TryLockError::WouldBlock) => { + if !reported_contention { + tracing::debug!( + target: LOCK_CONTENTION_EVENT_TARGET, + lock_path = %path.display(), + "waiting for another process to finish refreshing MCP OAuth credentials" + ); + reported_contention = true; + } + sleep(REFRESH_LOCK_RETRY_SLEEP).await; + } + Err(error) => return Err(std::io::Error::from(error)), + } + } + }) + .await + { + Ok(Ok(())) => {} + Ok(Err(error)) => { + return Err(error).with_context(|| { + format!("failed to lock OAuth refresh lock {}", path.display()) + }); + } + Err(_) => anyhow::bail!( + "timed out after {acquire_timeout:?} waiting for OAuth refresh lock {}", + path.display() + ), + } + + Ok(Self { _file: file }) + } +} + +fn refresh_lock_path(codex_home: &Path, store_key: &str) -> PathBuf { + // Credential coordination is deliberately scoped to the active CODEX_HOME, alongside File + // and Secrets state. Coordinating the process-global Direct keyring across distinct homes + // would require a separately defined global lock namespace and is outside this transaction. + // TODO(stevenlee): define a safe per-user, cross-platform rendezvous before extending Direct + // keyring coordination across distinct CODEX_HOME values. + let mut hasher = Sha256::new(); + hasher.update(store_key.as_bytes()); + let digest = hasher.finalize(); + codex_home + .join(REFRESH_LOCK_DIR) + .join(format!("{digest:x}.lock")) +} + +#[cfg(test)] +#[path = "refresh_lock_tests.rs"] +mod tests; diff --git a/codex-rs/rmcp-client/src/oauth/refresh_lock_tests.rs b/codex-rs/rmcp-client/src/oauth/refresh_lock_tests.rs new file mode 100644 index 000000000000..2b90e73996cc --- /dev/null +++ b/codex-rs/rmcp-client/src/oauth/refresh_lock_tests.rs @@ -0,0 +1,42 @@ +use super::RefreshCredentialLock; +use anyhow::Result; +use std::time::Duration; +use tempfile::tempdir; + +#[tokio::test] +async fn acquisition_times_out_without_stealing() -> Result<()> { + let codex_home = tempdir()?; + let store_key = "test-store-key"; + let held_lock = RefreshCredentialLock::acquire_in( + codex_home.path(), + store_key, + Duration::from_millis(/*millis*/ 100), + ) + .await?; + + let error = match RefreshCredentialLock::acquire_in( + codex_home.path(), + store_key, + Duration::from_millis(/*millis*/ 50), + ) + .await + { + Ok(_) => panic!("contending lock acquisition should time out"), + Err(error) => error, + }; + assert!( + error + .to_string() + .contains("timed out after 50ms waiting for OAuth refresh lock"), + "unexpected error: {error:#}" + ); + + drop(held_lock); + let _reacquired = RefreshCredentialLock::acquire_in( + codex_home.path(), + store_key, + Duration::from_millis(/*millis*/ 100), + ) + .await?; + Ok(()) +} diff --git a/codex-rs/rmcp-client/src/oauth/resolved_store.rs b/codex-rs/rmcp-client/src/oauth/resolved_store.rs new file mode 100644 index 000000000000..3919521883ba --- /dev/null +++ b/codex-rs/rmcp-client/src/oauth/resolved_store.rs @@ -0,0 +1,140 @@ +//! Resolves the configured MCP OAuth store and pins that concrete source for one client lifecycle. + +use anyhow::Context; +use anyhow::Result; +use codex_config::types::AuthKeyringBackendKind; +use codex_config::types::OAuthCredentialsStoreMode; +use codex_keyring_store::DefaultKeyringStore; +use codex_keyring_store::KeyringStore; +use tracing::warn; + +use super::StoredOAuthTokens; +use super::load_oauth_tokens_from_file; +use super::load_oauth_tokens_from_keyring; + +/// Concrete credential store resolved for one MCP OAuth client lifecycle. +/// +/// This is intentionally not durable. `Auto` may resolve differently in a later process, but a +/// client that loaded credentials from one store must reread, refresh, persist, and remove only +/// through that store. A mid-lifecycle backend failure is unexpected and must return an error +/// rather than falling back to another possibly stale refresh token. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ResolvedOAuthCredentialStore { + File, + Keyring(AuthKeyringBackendKind), +} + +#[derive(Debug)] +pub(crate) struct LoadedOAuthTokens { + pub(crate) tokens: StoredOAuthTokens, + pub(crate) store: ResolvedOAuthCredentialStore, +} + +pub(crate) fn load_oauth_tokens_with_source( + server_name: &str, + url: &str, + store_mode: OAuthCredentialsStoreMode, + keyring_backend_kind: AuthKeyringBackendKind, +) -> Result> { + let keyring_store = DefaultKeyringStore; + load_oauth_tokens_with_source_and_keyring_store( + &keyring_store, + server_name, + url, + store_mode, + keyring_backend_kind, + ) +} + +pub(super) fn load_oauth_tokens_with_source_and_keyring_store( + keyring_store: &K, + server_name: &str, + url: &str, + store_mode: OAuthCredentialsStoreMode, + keyring_backend_kind: AuthKeyringBackendKind, +) -> Result> { + match store_mode { + OAuthCredentialsStoreMode::Auto => load_oauth_tokens_from_keyring_with_fallback_to_file( + keyring_store, + keyring_backend_kind, + server_name, + url, + ), + OAuthCredentialsStoreMode::File => Ok(load_oauth_tokens_from_file(server_name, url)?.map( + |tokens| LoadedOAuthTokens { + tokens, + store: ResolvedOAuthCredentialStore::File, + }, + )), + OAuthCredentialsStoreMode::Keyring => Ok(load_oauth_tokens_from_keyring( + keyring_store, + keyring_backend_kind, + server_name, + url, + ) + .with_context(|| "failed to read OAuth tokens from keyring".to_string())? + .map(|tokens| LoadedOAuthTokens { + tokens, + store: ResolvedOAuthCredentialStore::Keyring(keyring_backend_kind), + })), + } +} + +pub(crate) fn load_oauth_tokens_from_resolved_store( + server_name: &str, + url: &str, + store: ResolvedOAuthCredentialStore, +) -> Result> { + match store { + ResolvedOAuthCredentialStore::File => load_oauth_tokens_from_file(server_name, url) + .context("failed to read OAuth tokens from resolved file storage"), + ResolvedOAuthCredentialStore::Keyring(keyring_backend_kind) => { + load_oauth_tokens_from_keyring( + &DefaultKeyringStore, + keyring_backend_kind, + server_name, + url, + ) + .context( + "failed to read OAuth tokens from resolved keyring storage; refusing file fallback", + ) + } + } +} + +pub(super) fn load_oauth_tokens_from_keyring_with_fallback_to_file< + K: KeyringStore + Clone + 'static, +>( + keyring_store: &K, + keyring_backend_kind: AuthKeyringBackendKind, + server_name: &str, + url: &str, +) -> Result> { + // Auto remains keyring-first at lifecycle startup. The returned source is then pinned by the + // client transport recipe and OAuth persistor so retries, recovery, and refresh work cannot + // hot-switch stores. + // TODO(stevenlee): Different processes can still resolve Auto to different stores when + // keyring availability differs. Solving that safely requires durable backend selection or + // reconciliation of legacy entries and is intentionally outside this stack. + match load_oauth_tokens_from_keyring(keyring_store, keyring_backend_kind, server_name, url) { + Ok(Some(tokens)) => Ok(Some(LoadedOAuthTokens { + tokens, + store: ResolvedOAuthCredentialStore::Keyring(keyring_backend_kind), + })), + Ok(None) => Ok( + load_oauth_tokens_from_file(server_name, url)?.map(|tokens| LoadedOAuthTokens { + tokens, + store: ResolvedOAuthCredentialStore::File, + }), + ), + Err(error) => { + warn!("failed to read OAuth tokens from keyring: {error}"); + Ok(load_oauth_tokens_from_file(server_name, url) + .with_context(|| format!("failed to read OAuth tokens from keyring: {error}"))? + .map(|tokens| LoadedOAuthTokens { + tokens, + store: ResolvedOAuthCredentialStore::File, + })) + } + } +} diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index c6527990fac0..eb3e0deb53f6 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -3,6 +3,7 @@ use std::ffi::OsString; use std::future::Future; use std::io; use std::sync::Arc; +use std::sync::OnceLock; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::time::Duration; @@ -64,9 +65,12 @@ use crate::elicitation_client_service::ElicitationClientService; use crate::http_client_adapter::StreamableHttpClientAdapter; use crate::http_client_adapter::StreamableHttpClientAdapterError; use crate::in_process_transport::InProcessTransportFactory; -use crate::load_oauth_tokens; +use crate::oauth::LoadedOAuthTokens; use crate::oauth::OAuthPersistor; +use crate::oauth::ResolvedOAuthCredentialStore; use crate::oauth::StoredOAuthTokens; +use crate::oauth::load_oauth_tokens_from_resolved_store; +use crate::oauth::load_oauth_tokens_with_source; use crate::oauth_http_client::OAuthHttpClientAdapter; use crate::stdio_server_launcher::StdioServerCommand; use crate::stdio_server_launcher::StdioServerLauncher; @@ -80,6 +84,7 @@ mod streamable_http_retry; use self::streamable_http_retry::HandshakeError; use self::streamable_http_retry::STREAMABLE_HTTP_RETRY_DELAYS_MS; +use self::streamable_http_retry::remaining_initialize_timeout; use self::streamable_http_retry::sleep_with_retry_deadline; enum PendingTransport { @@ -126,6 +131,7 @@ enum TransportRecipe { env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, keyring_backend_kind: AuthKeyringBackendKind, + resolved_store: Arc>, http_client: Arc, auth_provider: Option, }, @@ -402,6 +408,7 @@ impl RmcpClient { env_http_headers, store_mode, keyring_backend_kind, + resolved_store: Arc::new(OnceLock::new()), http_client, auth_provider, }; @@ -444,11 +451,13 @@ impl RmcpClient { } }; + let mut initialize_deadline = timeout.map(|duration| Instant::now() + duration); let (service, oauth_persistor) = self .connect_pending_transport_with_initialize_retries( pending_transport, client_service.clone(), timeout, + &mut initialize_deadline, ) .await?; @@ -491,7 +500,7 @@ impl RmcpClient { params: Option, timeout: Option, ) -> Result { - self.refresh_oauth_if_needed().await; + self.refresh_oauth_if_needed().await?; let result = self .run_service_operation("tools/list", timeout, move |service| { let params = params.clone(); @@ -508,7 +517,7 @@ impl RmcpClient { params: Option, timeout: Option, ) -> Result { - self.refresh_oauth_if_needed().await; + self.refresh_oauth_if_needed().await?; let result = self .run_service_operation("tools/list", timeout, move |service| { let params = params.clone(); @@ -553,7 +562,7 @@ impl RmcpClient { params: Option, timeout: Option, ) -> Result { - self.refresh_oauth_if_needed().await; + self.refresh_oauth_if_needed().await?; let result = self .run_service_operation("resources/list", timeout, move |service| { let params = params.clone(); @@ -569,7 +578,7 @@ impl RmcpClient { params: Option, timeout: Option, ) -> Result { - self.refresh_oauth_if_needed().await; + self.refresh_oauth_if_needed().await?; let result = self .run_service_operation("resources/templates/list", timeout, move |service| { let params = params.clone(); @@ -585,7 +594,7 @@ impl RmcpClient { params: ReadResourceRequestParams, timeout: Option, ) -> Result { - self.refresh_oauth_if_needed().await; + self.refresh_oauth_if_needed().await?; let result = self .run_service_operation("resources/read", timeout, move |service| { let params = params.clone(); @@ -603,7 +612,7 @@ impl RmcpClient { meta: Option, timeout: Option, ) -> Result { - self.refresh_oauth_if_needed().await; + self.refresh_oauth_if_needed().await?; let arguments = match arguments { Some(Value::Object(map)) => Some(map), Some(other) => { @@ -659,7 +668,7 @@ impl RmcpClient { method: &str, params: Option, ) -> Result<()> { - self.refresh_oauth_if_needed().await; + self.refresh_oauth_if_needed().await?; self.run_service_operation( "notifications/custom", /*timeout*/ None, @@ -689,7 +698,7 @@ impl RmcpClient { method: &str, params: Option, ) -> Result { - self.refresh_oauth_if_needed().await; + self.refresh_oauth_if_needed().await?; let response = self .run_service_operation("requests/custom", /*timeout*/ None, move |service| { let params = params.clone(); @@ -753,12 +762,11 @@ impl RmcpClient { } } - async fn refresh_oauth_if_needed(&self) { - if let Some(runtime) = self.oauth_persistor().await - && let Err(error) = runtime.refresh_if_needed().await - { - warn!("failed to refresh OAuth tokens: {error}"); + async fn refresh_oauth_if_needed(&self) -> Result<()> { + if let Some(runtime) = self.oauth_persistor().await { + runtime.refresh_if_needed().await?; } + Ok(()) } async fn create_pending_transport( @@ -781,6 +789,7 @@ impl RmcpClient { env_http_headers, store_mode, keyring_backend_kind, + resolved_store, http_client, auth_provider, } => { @@ -797,24 +806,50 @@ impl RmcpClient { && auth_provider.is_none() && !default_headers.contains_key(AUTHORIZATION) { - match load_oauth_tokens(server_name, url, *store_mode, *keyring_backend_kind) { - Ok(tokens) => tokens, - Err(err) => { - warn!("failed to read tokens for server `{server_name}`: {err}"); - None + if let Some(store) = resolved_store.get().copied() { + load_oauth_tokens_from_resolved_store(server_name, url, store)? + .map(|tokens| LoadedOAuthTokens { tokens, store }) + } else { + match load_oauth_tokens_with_source( + server_name, + url, + *store_mode, + *keyring_backend_kind, + ) { + Ok(tokens) => { + if let Some(loaded) = tokens.as_ref() { + // Transport retries and session recovery are part of the same + // client lifecycle. Pin the first concrete source in memory so + // rebuilding a transport never re-evaluates Auto and adopts a + // possibly stale credential from another store. + resolved_store.set(loaded.store).map_err(|_| { + anyhow!( + "OAuth credential store resolved concurrently for MCP server `{server_name}`" + ) + })?; + } + tokens + } + Err(err) => { + warn!("failed to read tokens for server `{server_name}`: {err}"); + None + } } } } else { None }; - if let Some(initial_tokens) = initial_oauth_tokens.clone() { + if let Some(LoadedOAuthTokens { + tokens: initial_tokens, + store: credential_store, + }) = initial_oauth_tokens + { match create_oauth_transport_and_runtime( server_name, url, initial_tokens.clone(), - *store_mode, - *keyring_backend_kind, + credential_store, default_headers.clone(), Arc::clone(http_client), ) @@ -880,6 +915,7 @@ impl RmcpClient { pending_transport: PendingTransport, client_service: ElicitationClientService, timeout: Option, + initialize_deadline: &mut Option, ) -> Result<( Arc>, Option, @@ -900,13 +936,26 @@ impl RmcpClient { PendingTransport::StreamableHttpWithOAuth { transport, oauth_persistor, - } => ( - service::serve_client(client_service, transport).boxed(), - Some(oauth_persistor), - ), + } => { + // `startup_timeout_sec` bounds MCP transport setup, retry delays, and the + // initialization handshake. OAuth refresh has independent lock and provider + // request bounds, and persistence must finish after a successful response, so the + // complete refresh transaction is deliberately excluded from that deadline. + let refresh_started_at = Instant::now(); + let refresh_result = oauth_persistor.refresh_if_needed().await; + if let Some(deadline) = initialize_deadline.as_mut() { + *deadline += refresh_started_at.elapsed(); + } + refresh_result?; + ( + service::serve_client(client_service, transport).boxed(), + Some(oauth_persistor), + ) + } }; - let service_result = match timeout { + let handshake_timeout = remaining_initialize_timeout(timeout, *initialize_deadline)?; + let service_result = match handshake_timeout { Some(duration) => match time::timeout(duration, transport).await { Ok(result) => { result.map_err(|source| anyhow::Error::from(HandshakeError { source })) @@ -1124,11 +1173,15 @@ impl RmcpClient { .clone() .ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?; let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?; + let mut initialize_deadline = initialize_context + .timeout + .map(|duration| Instant::now() + duration); let (service, oauth_persistor) = self .connect_pending_transport_with_initialize_retries( pending_transport, initialize_context.client_service, initialize_context.timeout, + &mut initialize_deadline, ) .await?; @@ -1157,8 +1210,7 @@ async fn create_oauth_transport_and_runtime( server_name: &str, url: &str, initial_tokens: StoredOAuthTokens, - credentials_store: OAuthCredentialsStoreMode, - keyring_backend_kind: AuthKeyringBackendKind, + credential_store: ResolvedOAuthCredentialStore, default_headers: HeaderMap, http_client: Arc, ) -> Result<( @@ -1202,8 +1254,7 @@ async fn create_oauth_transport_and_runtime( server_name.to_string(), url.to_string(), auth_manager, - credentials_store, - keyring_backend_kind, + credential_store, Some(initial_tokens), ); diff --git a/codex-rs/rmcp-client/src/streamable_http_retry.rs b/codex-rs/rmcp-client/src/streamable_http_retry.rs index 73da95de58ec..2794489b2a76 100644 --- a/codex-rs/rmcp-client/src/streamable_http_retry.rs +++ b/codex-rs/rmcp-client/src/streamable_http_retry.rs @@ -28,6 +28,7 @@ impl RmcpClient { initial_transport: PendingTransport, client_service: ElicitationClientService, timeout: Option, + initialize_deadline: &mut Option, ) -> Result<( Arc>, Option, @@ -37,7 +38,6 @@ impl RmcpClient { PendingTransport::StreamableHttp { .. } | PendingTransport::StreamableHttpWithOAuth { .. } => true, }; - let retry_deadline = timeout.map(|duration| Instant::now() + duration); let mut pending_transport = Some(initial_transport); for (attempt, retry_delay_ms) in STREAMABLE_HTTP_RETRY_DELAYS_MS @@ -50,7 +50,7 @@ impl RmcpClient { let transport = match pending_transport.take() { Some(transport) => transport, None => { - let remaining = remaining_initialize_timeout(timeout, retry_deadline)?; + let remaining = remaining_initialize_timeout(timeout, *initialize_deadline)?; match remaining { Some(remaining) => time::timeout( remaining, @@ -62,12 +62,11 @@ impl RmcpClient { } } }; - let attempt_timeout = remaining_initialize_timeout(timeout, retry_deadline)?; - match Self::connect_pending_transport( transport, client_service.clone(), - attempt_timeout, + timeout, + initialize_deadline, ) .await { @@ -84,7 +83,7 @@ impl RmcpClient { error = %error, "streamable HTTP MCP initialize failed with a retryable error; retrying" ); - if !sleep_with_retry_deadline(delay, retry_deadline).await { + if !sleep_with_retry_deadline(delay, *initialize_deadline).await { let duration = timeout.unwrap_or(delay); return Err(anyhow!( "timed out handshaking with MCP server after {duration:?}" @@ -194,7 +193,7 @@ fn is_retryable_http_status(status: StatusCode) -> bool { ) } -fn remaining_initialize_timeout( +pub(super) fn remaining_initialize_timeout( timeout: Option, deadline: Option, ) -> Result> { @@ -209,7 +208,10 @@ fn remaining_initialize_timeout( } } -fn initialize_timeout_error(timeout: Option, fallback: Duration) -> anyhow::Error { +pub(super) fn initialize_timeout_error( + timeout: Option, + fallback: Duration, +) -> anyhow::Error { let duration = timeout.unwrap_or(fallback); anyhow!("timed out handshaking with MCP server after {duration:?}") } diff --git a/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs b/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs index 866daa67a4c0..d578f40a3ebd 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_oauth_startup.rs @@ -1,9 +1,19 @@ mod streamable_http_test_support; +use std::fs; +use std::path::Path; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; use std::time::Duration; use std::time::SystemTime; use std::time::UNIX_EPOCH; +use axum::Json; +use axum::Router; +use axum::extract::State; +use axum::routing::post; use codex_config::types::AuthKeyringBackendKind; use codex_config::types::OAuthCredentialsStoreMode; use codex_exec_server::Environment; @@ -22,7 +32,18 @@ use rmcp::transport::auth::VendorExtraTokenFields; use serde_json::Value; use serde_json::json; use tempfile::TempDir; +use tokio::net::TcpListener; use tokio::process::Command; +use tokio::sync::Notify; +use tokio::sync::Semaphore; +use tokio::task::JoinHandle; +use tracing::Event; +use tracing::Id; +use tracing::Metadata; +use tracing::Subscriber; +use tracing::span::Attributes; +use tracing::span::Record; +use tracing::subscriber::Interest; use wiremock::Mock; use wiremock::MockServer; use wiremock::Request; @@ -33,18 +54,180 @@ use wiremock::matchers::method; use wiremock::matchers::path; use streamable_http_test_support::initialize_client; +use streamable_http_test_support::initialize_client_with_timeout; const SERVER_NAME: &str = "test-streamable-http-oauth-startup"; const EXPIRED_ACCESS_TOKEN: &str = "expired-access-token"; const REFRESH_TOKEN: &str = "valid-refresh-token"; const REFRESHED_ACCESS_TOKEN: &str = "refreshed-access-token"; +const ROTATED_REFRESH_TOKEN: &str = "rotated-refresh-token"; +const CHILD_CONTENTION_FILE_ENV: &str = "MCP_TEST_OAUTH_STARTUP_CONTENTION_FILE"; +const CHILD_READY_FILE_ENV: &str = "MCP_TEST_OAUTH_STARTUP_READY_FILE"; +const CHILD_RELEASE_FILE_ENV: &str = "MCP_TEST_OAUTH_STARTUP_RELEASE_FILE"; +const PREFLIGHT_REFRESH_ERROR: &str = "preflight refresh failed distinctly"; const CHILD_SERVER_URL_ENV: &str = "MCP_TEST_OAUTH_STARTUP_SERVER_URL"; +// This mirrors the private event target in oauth::refresh_lock without exposing test-only crate API. +const LOCK_CONTENTION_EVENT_TARGET: &str = "codex_rmcp_client::oauth::refresh_lock::contention"; const UNREFRESHABLE_SERVER_URL: &str = "https://unrefreshable.example/mcp"; const UNEXPIRED_SERVER_URL: &str = "https://unexpired.example/mcp"; const REFRESHABLE_SERVER_URL: &str = "https://refreshable.example/mcp"; +#[derive(Clone)] +struct GatedRefreshState { + request_count: Arc, + request_started: Arc, + response_release: Arc, +} + +struct GatedRefreshServer { + token_endpoint: String, + state: GatedRefreshState, + task: JoinHandle<()>, +} + +impl GatedRefreshServer { + async fn start() -> anyhow::Result { + let state = GatedRefreshState { + request_count: Arc::new(AtomicUsize::new(/*v*/ 0)), + request_started: Arc::new(Notify::new()), + response_release: Arc::new(Semaphore::new(/*permits*/ 0)), + }; + let router = Router::new() + .route("/oauth/token", post(gated_refresh_response)) + .with_state(state.clone()); + let listener = TcpListener::bind("127.0.0.1:0").await?; + let address = listener.local_addr()?; + let task = tokio::spawn(async move { + if let Err(error) = axum::serve(listener, router).await { + panic!("gated refresh server failed: {error}"); + } + }); + Ok(Self { + token_endpoint: format!("http://{address}/oauth/token"), + state, + task, + }) + } + + async fn wait_until_request_started(&self) -> anyhow::Result<()> { + let notified = self.state.request_started.notified(); + if self.request_count() == 0 { + tokio::time::timeout(Duration::from_secs(/*secs*/ 10), notified) + .await + .map_err(|_| anyhow::anyhow!("provider refresh request did not start"))?; + } + Ok(()) + } + + fn release_responses(&self) { + // Two permits also let the no-lock negative control exit cleanly after issuing two refresh + // requests. The passing path consumes one permit because only the lock owner calls out. + self.state.response_release.add_permits(/*n*/ 2); + } + + fn request_count(&self) -> usize { + self.state.request_count.load(Ordering::SeqCst) + } +} + +impl Drop for GatedRefreshServer { + fn drop(&mut self) { + self.task.abort(); + } +} + +async fn gated_refresh_response( + State(state): State, + body: String, +) -> Json { + assert!(body.contains("grant_type=refresh_token")); + assert!(body.contains(&format!("refresh_token={REFRESH_TOKEN}"))); + state.request_count.fetch_add(/*val*/ 1, Ordering::SeqCst); + state.request_started.notify_one(); + let Ok(permit) = state.response_release.acquire().await else { + panic!("gated refresh server closed its response semaphore"); + }; + permit.forget(); + Json(json!({ + "access_token": REFRESHED_ACCESS_TOKEN, + "token_type": "Bearer", + "expires_in": 7200, + "refresh_token": ROTATED_REFRESH_TOKEN, + })) +} + +struct LockContentionMarkerSubscriber { + marker_file: PathBuf, +} + +impl Subscriber for LockContentionMarkerSubscriber { + fn enabled(&self, metadata: &Metadata<'_>) -> bool { + metadata.target() == LOCK_CONTENTION_EVENT_TARGET + } + + fn register_callsite(&self, metadata: &'static Metadata<'static>) -> Interest { + if self.enabled(metadata) { + Interest::always() + } else { + Interest::never() + } + } + + fn max_level_hint(&self) -> Option { + Some(tracing::level_filters::LevelFilter::DEBUG) + } + + fn new_span(&self, _span: &Attributes<'_>) -> Id { + // This subscriber enables only the contention event callsite, so it never observes spans. + Id::from_u64(/*u*/ 1) + } + + fn record(&self, _span: &Id, _values: &Record<'_>) {} + + fn record_follows_from(&self, _span: &Id, _follows: &Id) {} + + fn event(&self, event: &Event<'_>) { + if event.metadata().target() == LOCK_CONTENTION_EVENT_TARGET + && let Err(error) = fs::write(&self.marker_file, b"contended") + { + panic!("failed to write refresh-lock contention marker: {error}"); + } + } + + fn enter(&self, _span: &Id) {} + + fn exit(&self, _span: &Id) {} +} + +async fn wait_for_marker(path: &Path, timeout_message: &str) -> anyhow::Result<()> { + tokio::time::timeout(Duration::from_secs(/*secs*/ 10), async { + while !path.exists() { + tokio::time::sleep(Duration::from_millis(/*millis*/ 10)).await; + } + }) + .await + .map_err(|_| anyhow::anyhow!("{timeout_message}")) +} + +fn oauth_concurrency_child_command( + codex_home: &Path, + server_url: &str, + ready_file: &Path, + release_file: &Path, +) -> anyhow::Result { + let mut command = Command::new(std::env::current_exe()?); + command + .args(["oauth_concurrency_client_child", "--exact", "--ignored"]) + .env("CODEX_HOME", codex_home) + .env(CHILD_SERVER_URL_ENV, server_url) + .env(CHILD_READY_FILE_ENV, ready_file) + .env(CHILD_RELEASE_FILE_ENV, release_file) + .kill_on_drop(true); + Ok(command) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn refreshes_expired_persisted_token_before_initialize() -> anyhow::Result<()> { +async fn startup_refresh_does_not_consume_handshake_timeout() -> anyhow::Result<()> { let server = MockServer::start().await; Mock::given(method("GET")) .and(path("/.well-known/oauth-authorization-server/mcp")) @@ -62,12 +245,18 @@ async fn refreshes_expired_persisted_token_before_initialize() -> anyhow::Result .and(body_string_contains(format!( "refresh_token={REFRESH_TOKEN}" ))) - .respond_with(ResponseTemplate::new(200).set_body_json(json!({ - "access_token": REFRESHED_ACCESS_TOKEN, - "token_type": "Bearer", - "expires_in": 7200, - "refresh_token": REFRESH_TOKEN, - }))) + // The provider takes longer than the configured MCP handshake timeout. Refresh has its own + // bound, so this delay must not leave the subsequent handshake with an expired budget. + .respond_with( + ResponseTemplate::new(200) + .set_delay(Duration::from_millis(1_500)) + .set_body_json(json!({ + "access_token": REFRESHED_ACCESS_TOKEN, + "token_type": "Bearer", + "expires_in": 7200, + "refresh_token": REFRESH_TOKEN, + })), + ) .expect(1) .mount(&server) .await; @@ -121,6 +310,222 @@ async fn refreshes_expired_persisted_token_before_initialize() -> anyhow::Result Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn concurrent_file_mode_startup_refreshes_once() -> anyhow::Result<()> { + let server = MockServer::start().await; + let refresh_server = GatedRefreshServer::start().await?; + let codex_home = TempDir::new()?; + let server_url = format!("{}/mcp", server.uri()); + Mock::given(method("GET")) + .and(path("/.well-known/oauth-authorization-server/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "authorization_endpoint": format!("{}/oauth/authorize", server.uri()), + "token_endpoint": refresh_server.token_endpoint.clone(), + "scopes_supported": [""], + }))) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header( + "authorization", + format!("Bearer {REFRESHED_ACCESS_TOKEN}"), + )) + .respond_with(|request: &Request| { + let body: Value = request.body_json().expect("valid JSON-RPC request"); + match body.get("method").and_then(Value::as_str) { + Some("initialize") => ResponseTemplate::new(200).set_body_json(json!({ + "jsonrpc": "2.0", + "id": body.get("id").cloned().unwrap_or(Value::Null), + "result": { + "protocolVersion": body + .pointer("/params/protocolVersion") + .cloned() + .unwrap_or_else(|| json!("2025-06-18")), + "capabilities": {}, + "serverInfo": { + "name": "oauth-startup-test", + "version": "0.0.0-test", + }, + }, + })), + Some("notifications/initialized") => ResponseTemplate::new(202), + method => ResponseTemplate::new(400) + .set_body_string(format!("unexpected JSON-RPC method: {method:?}")), + } + }) + .expect(4) + .mount(&server) + .await; + + let seed_status = Command::new(std::env::current_exe()?) + .args(["oauth_concurrency_seed_child", "--exact", "--ignored"]) + .env("CODEX_HOME", codex_home.path()) + .env(CHILD_SERVER_URL_ENV, &server_url) + .status() + .await?; + assert!( + seed_status.success(), + "OAuth concurrency seed child failed: {seed_status}" + ); + + let first_ready_file = codex_home.path().join("oauth-client-first.ready"); + let first_release_file = codex_home.path().join("oauth-client-first.release"); + let second_ready_file = codex_home.path().join("oauth-client-second.ready"); + let second_release_file = codex_home.path().join("oauth-client-second.release"); + let contention_file = codex_home.path().join("oauth-client-second.contended"); + let mut first_child = oauth_concurrency_child_command( + codex_home.path(), + &server_url, + &first_ready_file, + &first_release_file, + )? + .spawn()?; + let mut second_command = oauth_concurrency_child_command( + codex_home.path(), + &server_url, + &second_ready_file, + &second_release_file, + )?; + second_command.env(CHILD_CONTENTION_FILE_ENV, &contention_file); + let mut second_child = second_command.spawn()?; + + wait_for_marker( + &first_ready_file, + "first OAuth concurrency child did not become ready", + ) + .await?; + wait_for_marker( + &second_ready_file, + "second OAuth concurrency child did not become ready", + ) + .await?; + + fs::write(&first_release_file, b"release")?; + refresh_server.wait_until_request_started().await?; + + // The first child is now inside the provider request while retaining the credential lock. + // Releasing the second child must make its first try_lock call observe WouldBlock. Keep the + // provider response gated until that exact branch emits the contention marker. + fs::write(&second_release_file, b"release")?; + let contention_result = wait_for_marker( + &contention_file, + "second OAuth concurrency child did not observe refresh-lock contention", + ) + .await; + refresh_server.release_responses(); + + let (first_status, second_status) = tokio::try_join!(first_child.wait(), second_child.wait())?; + assert!( + first_status.success(), + "first OAuth concurrency child failed: {first_status}" + ); + assert!( + second_status.success(), + "second OAuth concurrency child failed: {second_status}" + ); + + server.verify().await; + contention_result?; + assert_eq!(refresh_server.request_count(), 1); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn operation_preflight_refresh_failure_blocks_rmcp_request() -> anyhow::Result<()> { + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/.well-known/oauth-authorization-server/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "authorization_endpoint": format!("{}/oauth/authorize", server.uri()), + "token_endpoint": format!("{}/oauth/token", server.uri()), + "scopes_supported": [""], + }))) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/oauth/token")) + .and(body_string_contains("grant_type=refresh_token")) + .and(body_string_contains(format!( + "refresh_token={REFRESH_TOKEN}" + ))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "access_token": REFRESHED_ACCESS_TOKEN, + "token_type": "Bearer", + "expires_in": 31, + "refresh_token": ROTATED_REFRESH_TOKEN, + }))) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/oauth/token")) + .and(body_string_contains("grant_type=refresh_token")) + .and(body_string_contains(format!( + "refresh_token={ROTATED_REFRESH_TOKEN}" + ))) + .respond_with(ResponseTemplate::new(400).set_body_json(json!({ + "error": "invalid_grant", + "error_description": PREFLIGHT_REFRESH_ERROR, + }))) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/mcp")) + .and(header( + "authorization", + format!("Bearer {REFRESHED_ACCESS_TOKEN}"), + )) + .respond_with(|request: &Request| { + let body: Value = request.body_json().expect("valid JSON-RPC request"); + match body.get("method").and_then(Value::as_str) { + Some("initialize") => ResponseTemplate::new(200).set_body_json(json!({ + "jsonrpc": "2.0", + "id": body.get("id").cloned().unwrap_or(Value::Null), + "result": { + "protocolVersion": body + .pointer("/params/protocolVersion") + .cloned() + .unwrap_or_else(|| json!("2025-06-18")), + "capabilities": {}, + "serverInfo": { + "name": "oauth-preflight-test", + "version": "0.0.0-test", + }, + }, + })), + Some("notifications/initialized") => ResponseTemplate::new(202), + method => ResponseTemplate::new(400) + .set_body_string(format!("unexpected JSON-RPC method: {method:?}")), + } + }) + .expect(2) + .mount(&server) + .await; + + let codex_home = TempDir::new()?; + let server_url = format!("{}/mcp", server.uri()); + let status = Command::new(std::env::current_exe()?) + .args([ + "operation_preflight_refresh_failure_child", + "--exact", + "--ignored", + "--nocapture", + ]) + .env("CODEX_HOME", codex_home.path()) + .env(CHILD_SERVER_URL_ENV, &server_url) + .status() + .await?; + assert!( + status.success(), + "OAuth preflight failure child failed: {status}" + ); + + server.verify().await; + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn reports_auth_status_for_persisted_credentials() -> anyhow::Result<()> { let codex_home = TempDir::new()?; @@ -233,7 +638,7 @@ async fn auth_status(server_url: &str) -> anyhow::Result { } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] -#[ignore = "spawned by refreshes_expired_persisted_token_before_initialize"] +#[ignore = "spawned by startup_refresh_does_not_consume_handshake_timeout"] async fn oauth_startup_child() -> anyhow::Result<()> { let server_url = std::env::var(CHILD_SERVER_URL_ENV)?; @@ -276,6 +681,111 @@ async fn oauth_startup_child() -> anyhow::Result<()> { ) .await?; + initialize_client_with_timeout(&client, Duration::from_secs(1)).await?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[ignore = "spawned by operation_preflight_refresh_failure_blocks_rmcp_request"] +async fn operation_preflight_refresh_failure_child() -> anyhow::Result<()> { + let server_url = std::env::var(CHILD_SERVER_URL_ENV)?; + save_expired_file_mode_tokens(&server_url)?; + + let client = RmcpClient::new_streamable_http_client( + SERVER_NAME, + &server_url, + /*bearer_token*/ None, + /*http_headers*/ None, + /*env_http_headers*/ None, + OAuthCredentialsStoreMode::File, + AuthKeyringBackendKind::default(), + Environment::default_for_tests().get_http_client(), + /*auth_provider*/ None, + ) + .await?; + + initialize_client(&client).await?; + + tokio::time::sleep(Duration::from_millis(1_200)).await; + let error = client + .list_tools(/*params*/ None, Some(Duration::from_secs(5))) + .await + .expect_err("preflight refresh failure should abort the operation"); + let message = format!("{error:#}"); + assert!( + message.contains("failed to refresh OAuth tokens for server"), + "unexpected error: {message}" + ); + assert!( + message.contains(PREFLIGHT_REFRESH_ERROR), + "unexpected error: {message}" + ); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[ignore = "spawned by concurrent_file_mode_startup_refreshes_once"] +async fn oauth_concurrency_seed_child() -> anyhow::Result<()> { + let server_url = std::env::var(CHILD_SERVER_URL_ENV)?; + save_expired_file_mode_tokens(&server_url)?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[ignore = "spawned by concurrent_file_mode_startup_refreshes_once"] +async fn oauth_concurrency_client_child() -> anyhow::Result<()> { + let server_url = std::env::var(CHILD_SERVER_URL_ENV)?; + let ready_file = PathBuf::from(std::env::var(CHILD_READY_FILE_ENV)?); + let release_file = PathBuf::from(std::env::var(CHILD_RELEASE_FILE_ENV)?); + if let Ok(marker_file) = std::env::var(CHILD_CONTENTION_FILE_ENV) { + tracing::subscriber::set_global_default(LockContentionMarkerSubscriber { + marker_file: PathBuf::from(marker_file), + }) + .map_err(|error| anyhow::anyhow!("failed to install contention subscriber: {error}"))?; + } + let client = RmcpClient::new_streamable_http_client( + SERVER_NAME, + &server_url, + /*bearer_token*/ None, + /*http_headers*/ None, + /*env_http_headers*/ None, + OAuthCredentialsStoreMode::File, + AuthKeyringBackendKind::default(), + Environment::default_for_tests().get_http_client(), + /*auth_provider*/ None, + ) + .await?; + + // Both processes must construct their OAuth client from the same expired snapshot before the + // parent releases either one. The parent then gates them separately to force lock contention. + fs::write(ready_file, b"ready")?; + while !release_file.exists() { + tokio::time::sleep(Duration::from_millis(/*millis*/ 10)).await; + } initialize_client(&client).await?; Ok(()) } + +fn save_expired_file_mode_tokens(server_url: &str) -> anyhow::Result<()> { + let mut response = OAuthTokenResponse::new( + AccessToken::new(EXPIRED_ACCESS_TOKEN.to_string()), + BasicTokenType::Bearer, + VendorExtraTokenFields::default(), + ); + response.set_refresh_token(Some(RefreshToken::new(REFRESH_TOKEN.to_string()))); + response.set_expires_in(Some(&Duration::from_secs(7200))); + let tokens = StoredOAuthTokens { + server_name: SERVER_NAME.to_string(), + url: server_url.to_string(), + client_id: "test-client-id".to_string(), + token_response: WrappedOAuthTokenResponse(response), + expires_at: Some(0), + }; + save_oauth_tokens( + SERVER_NAME, + &tokens, + OAuthCredentialsStoreMode::File, + AuthKeyringBackendKind::default(), + )?; + Ok(()) +} diff --git a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs index 87a52f55dff6..532861de7ddf 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs @@ -106,10 +106,17 @@ pub(crate) async fn create_client_with_http_client( } pub(crate) async fn initialize_client(client: &RmcpClient) -> anyhow::Result<()> { + initialize_client_with_timeout(client, Duration::from_secs(5)).await +} + +pub(crate) async fn initialize_client_with_timeout( + client: &RmcpClient, + timeout: Duration, +) -> anyhow::Result<()> { client .initialize( init_params(), - Some(Duration::from_secs(5)), + Some(timeout), Box::new(|_, _| { async { Ok(ElicitationResponse {