Skip to content
Open
51 changes: 48 additions & 3 deletions codex-rs/rmcp-client/src/http_client_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use codex_exec_server::HttpResponseBodyStream;
use futures::StreamExt;
use futures::stream;
use futures::stream::BoxStream;
use oauth2::AccessToken;
use reqwest::StatusCode;
use reqwest::header::ACCEPT;
use reqwest::header::AUTHORIZATION;
Expand Down Expand Up @@ -55,12 +56,17 @@ pub(crate) struct StreamableHttpClientAdapter {
http_client: Arc<dyn HttpClient>,
default_headers: HeaderMap,
auth_provider: Option<SharedAuthProvider>,
attribute_rejected_access_token: bool,
}

#[derive(Debug, thiserror::Error)]
pub(crate) enum StreamableHttpClientAdapterError {
#[error("streamable HTTP session expired with 404 Not Found")]
SessionExpired404,
#[error("MCP server rejected the access token with HTTP 401 Unauthorized")]
AccessTokenRejected { rejected_access_token: AccessToken },
#[error("MCP OAuth operation failed: {0:#}")]
OAuth(#[source] anyhow::Error),
#[error(transparent)]
HttpRequest(#[from] ExecServerError),
#[error("invalid HTTP header: {0}")]
Expand All @@ -77,8 +83,15 @@ impl StreamableHttpClientAdapter {
http_client,
default_headers,
auth_provider,
attribute_rejected_access_token: false,
}
}

/// Preserves the access token associated with a 401 for Codex-owned OAuth recovery.
pub(crate) fn with_rejected_token_attribution(mut self) -> Self {
self.attribute_rejected_access_token = true;
self
}
}

impl StreamableHttpClient for StreamableHttpClientAdapter {
Expand Down Expand Up @@ -109,7 +122,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
JSON_MIME_TYPE.to_string(),
StreamableHttpClientAdapterError::Header,
)?;
if let Some(auth_token) = auth_token {
if let Some(auth_token) = auth_token.as_deref() {
insert_header(
&mut headers,
AUTHORIZATION,
Expand Down Expand Up @@ -162,6 +175,12 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
StreamableHttpClientAdapterError::SessionExpired404,
));
}
if self.attribute_rejected_access_token
&& response.status == StatusCode::UNAUTHORIZED.as_u16()
&& let Some(error) = access_token_rejected(auth_token.as_deref())
{
return Err(error);
}
if response.status == StatusCode::UNAUTHORIZED.as_u16()
&& let Some(header) =
response_header(&response.headers, reqwest::header::WWW_AUTHENTICATE)
Expand Down Expand Up @@ -240,7 +259,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
let mut headers = self.default_headers.clone();
headers.extend(custom_headers);
self.add_auth_headers(&mut headers);
if let Some(auth_token) = auth_token {
if let Some(auth_token) = auth_token.as_deref() {
insert_header(
&mut headers,
AUTHORIZATION,
Expand Down Expand Up @@ -274,6 +293,12 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
if response.status == StatusCode::METHOD_NOT_ALLOWED.as_u16() {
return Ok(());
}
if self.attribute_rejected_access_token
&& response.status == StatusCode::UNAUTHORIZED.as_u16()
&& let Some(error) = access_token_rejected(auth_token.as_deref())
{
return Err(error);
}
if !status_is_success(response.status) {
return Err(StreamableHttpError::UnexpectedServerResponse(
format!("DELETE returned HTTP {}", response.status).into(),
Expand Down Expand Up @@ -316,7 +341,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
StreamableHttpClientAdapterError::Header,
)?;
}
if let Some(auth_token) = auth_token {
if let Some(auth_token) = auth_token.as_deref() {
insert_header(
&mut headers,
AUTHORIZATION,
Expand Down Expand Up @@ -349,6 +374,12 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
StreamableHttpClientAdapterError::SessionExpired404,
));
}
if self.attribute_rejected_access_token
&& response.status == StatusCode::UNAUTHORIZED.as_u16()
&& let Some(error) = access_token_rejected(auth_token.as_deref())
{
return Err(error);
}
if !status_is_success(response.status) {
return Err(StreamableHttpError::UnexpectedServerResponse(
format!("GET returned HTTP {}", response.status).into(),
Expand All @@ -371,6 +402,20 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
}
}

fn access_token_rejected(
auth_token: Option<&str>,
) -> Option<StreamableHttpError<StreamableHttpClientAdapterError>> {
// Preserve the token associated with this response. Reading the current credential after a
// delayed 401 is racy: another concurrent request may already have refreshed A to B, in which
// case recovery must retry B rather than refresh B a second time. AccessToken's Debug
// implementation redacts the secret if this error is logged.
auth_token.map(|rejected_access_token| {
StreamableHttpError::Client(StreamableHttpClientAdapterError::AccessTokenRejected {
rejected_access_token: AccessToken::new(rejected_access_token.to_string()),
})
})
}

impl StreamableHttpClientAdapter {
fn add_auth_headers(&self, headers: &mut HeaderMap) {
if let Some(auth_provider) = &self.auth_provider {
Expand Down
1 change: 1 addition & 0 deletions codex-rs/rmcp-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod in_process_transport;
mod logging_client_handler;
mod oauth;
mod oauth_http_client;
mod oauth_transport;
mod perform_oauth_login;
mod program_resolver;
mod rmcp_client;
Expand Down
82 changes: 1 addition & 81 deletions codex-rs/rmcp-client/src/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ use tokio::sync::Mutex;

use codex_utils_home_dir::find_codex_home;

pub(crate) use self::refresh_transaction::request_oauth_token_response;
pub(crate) use self::resolved_store::ResolvedOAuthCredentialStore;
pub(crate) use self::resolved_store::ResolvedOAuthTokens;
pub(crate) use self::resolved_store::load_oauth_tokens_from_store;
Expand Down Expand Up @@ -489,67 +490,6 @@ impl OAuthPersistor {
}),
}
}

/// Persists RMCP-managed credential changes back to this client's resolved authority.
#[expect(
clippy::await_holding_invalid_type,
reason = "AuthorizationManager async access must be serialized through its mutex"
)]
pub(crate) async fn persist_if_needed(&self) -> Result<()> {
let (client_id, maybe_credentials) = {
let manager = self.inner.authorization_manager.clone();
let guard = manager.lock().await;
guard.get_credentials().await
}?;

match maybe_credentials {
Some(credentials) => {
let mut last_credentials = self.inner.last_credentials.lock().await;
let new_token_response = WrappedOAuthTokenResponse(credentials.clone());
let same_token = last_credentials
.as_ref()
.map(|previous| previous.token_response == new_token_response)
.unwrap_or(false);
let expires_at = if same_token {
last_credentials
.as_ref()
.and_then(|previous| previous.expires_at)
} else {
compute_expires_at_millis(&credentials)
};
let stored = StoredOAuthTokens {
server_name: self.inner.server_name.clone(),
url: self.inner.url.clone(),
client_id,
token_response: new_token_response,
expires_at,
};
if last_credentials.as_ref() != Some(&stored) {
save_to_resolved_store(&DefaultKeyringStore, &self.inner, &stored)?;
*last_credentials = Some(stored);
}
}
None => {
let mut last_credentials = self.inner.last_credentials.lock().await;
if last_credentials.take().is_some()
&& let Err(error) = delete_from_resolved_store(
&DefaultKeyringStore,
&self.inner.server_name,
&self.inner.url,
self.inner.credential_store,
)
{
warn!(
server_name = %self.inner.server_name,
error = %error,
"failed to remove MCP OAuth credentials from the resolved store"
);
}
}
}

Ok(())
}
}

fn save_to_resolved_store<K: KeyringStore + Clone + 'static>(
Expand All @@ -570,26 +510,6 @@ fn save_to_resolved_store<K: KeyringStore + Clone + 'static>(
}
}

fn delete_from_resolved_store<K: KeyringStore + Clone + 'static>(
keyring_store: &K,
server_name: &str,
url: &str,
credential_store: ResolvedOAuthCredentialStore,
) -> Result<bool> {
match credential_store {
ResolvedOAuthCredentialStore::File => {
let key = compute_store_key(server_name, url)?;
delete_oauth_tokens_from_file(&key)
}
ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Direct) => {
delete_oauth_tokens_from_direct_keyring(keyring_store, server_name, url)
}
ResolvedOAuthCredentialStore::Keyring(AuthKeyringBackendKind::Secrets) => {
delete_oauth_tokens_from_secrets_keyring(keyring_store, server_name, url)
}
}
}

const FALLBACK_FILENAME: &str = ".credentials.json";
const MCP_SERVER_TYPE: &str = "http";

Expand Down
Loading
Loading