From 2f818abb05acac5c9188cec000c6fa59e8d530ff Mon Sep 17 00:00:00 2001 From: Farhan Toddywala Date: Sat, 27 Jun 2026 02:20:41 +0000 Subject: [PATCH] Read faster model from safety buffering events --- codex-rs/app-server/tests/suite/v2/mod.rs | 1 + .../tests/suite/v2/safety_buffering.rs | 133 +++++++++++++++++ codex-rs/codex-api/src/common.rs | 11 +- .../src/endpoint/responses_websocket.rs | 139 ++++++++++++++++-- codex-rs/codex-api/src/sse/responses.rs | 10 +- 5 files changed, 277 insertions(+), 17 deletions(-) create mode 100644 codex-rs/app-server/tests/suite/v2/safety_buffering.rs diff --git a/codex-rs/app-server/tests/suite/v2/mod.rs b/codex-rs/app-server/tests/suite/v2/mod.rs index 9963d37be5d2..ffad02e9b4f8 100644 --- a/codex-rs/app-server/tests/suite/v2/mod.rs +++ b/codex-rs/app-server/tests/suite/v2/mod.rs @@ -55,6 +55,7 @@ mod request_permissions; mod request_user_input; mod request_validation; mod review; +mod safety_buffering; mod safety_check_downgrade; #[cfg(not(target_os = "windows"))] mod selected_capability_stack; diff --git a/codex-rs/app-server/tests/suite/v2/safety_buffering.rs b/codex-rs/app-server/tests/suite/v2/safety_buffering.rs new file mode 100644 index 000000000000..81d0fbd82121 --- /dev/null +++ b/codex-rs/app-server/tests/suite/v2/safety_buffering.rs @@ -0,0 +1,133 @@ +use anyhow::Result; +use app_test_support::TestAppServer; +use app_test_support::to_response; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::ModelSafetyBufferingUpdatedNotification; +use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::ThreadStartParams; +use codex_app_server_protocol::ThreadStartResponse; +use codex_app_server_protocol::TurnStartParams; +use codex_app_server_protocol::TurnStartResponse; +use codex_app_server_protocol::UserInput; +use core_test_support::responses; +use core_test_support::skip_if_no_network; +use pretty_assertions::assert_eq; +use serde_json::json; +use std::path::Path; +use tempfile::TempDir; +use tokio::time::Duration; +use tokio::time::timeout; + +const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(20); +const FASTER_MODEL: &str = "gpt-fast-wire"; + +#[tokio::test] +async fn direct_websocket_safety_buffering_reaches_app_server_notification() -> Result<()> { + skip_if_no_network!(Ok(())); + + let mut buffering_event = responses::ev_response_created("resp-1"); + buffering_event["safety_buffering"] = json!({ + "use_cases": ["cyber"], + "reasons": ["user_risk"], + "faster_model": FASTER_MODEL, + }); + let websocket_server = responses::start_websocket_server(vec![vec![ + vec![ + json!({ + "type": "codex.response.metadata", + "headers": {"x-codex-safety-buffering-enabled": "false"}, + }), + responses::ev_response_created("warm-1"), + responses::ev_completed("warm-1"), + ], + vec![ + buffering_event, + responses::ev_assistant_message("msg-1", "Done"), + responses::ev_completed("resp-1"), + ], + ]]) + .await; + + let codex_home = TempDir::new()?; + create_websocket_config( + codex_home.path(), + &websocket_server.uri().replacen("ws://", "http://", 1), + )?; + let mut mcp = TestAppServer::new_with_auto_env(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let thread_request = mcp + .send_thread_start_request_with_auto_env(ThreadStartParams::default()) + .await?; + let thread_response: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(thread_request)), + ) + .await??; + let ThreadStartResponse { thread, .. } = to_response(thread_response)?; + + let turn_request = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id.clone(), + client_user_message_id: None, + input: vec![UserInput::Text { + text: "Check this request".to_string(), + text_elements: Vec::new(), + }], + ..Default::default() + }) + .await?; + let turn_response: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(turn_request)), + ) + .await??; + let TurnStartResponse { turn } = to_response(turn_response)?; + + let notification = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("model/safetyBuffering/updated"), + ) + .await??; + let notification: ModelSafetyBufferingUpdatedNotification = + serde_json::from_value(notification.params.expect("notification params"))?; + + assert_eq!(notification.thread_id, thread.id); + assert_eq!(notification.turn_id, turn.id); + assert_eq!(notification.model, "mock-model"); + assert_eq!(notification.use_cases, ["cyber"]); + assert_eq!(notification.reasons, ["user_risk"]); + assert!(notification.show_buffering_ui); + assert_eq!(notification.faster_model.as_deref(), Some(FASTER_MODEL)); + + timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; + websocket_server.shutdown().await; + Ok(()) +} + +fn create_websocket_config(codex_home: &Path, server_uri: &str) -> std::io::Result<()> { + std::fs::write( + codex_home.join("config.toml"), + format!( + r#" +model = "mock-model" +approval_policy = "never" +sandbox_mode = "read-only" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "responses" +request_max_retries = 0 +stream_max_retries = 0 +supports_websockets = true +"# + ), + ) +} diff --git a/codex-rs/codex-api/src/common.rs b/codex-rs/codex-api/src/common.rs index bd037deb5838..7f4b3974a69e 100644 --- a/codex-rs/codex-api/src/common.rs +++ b/codex-rs/codex-api/src/common.rs @@ -121,14 +121,17 @@ pub struct SafetyBuffering { pub reasons: Vec, #[serde(skip)] pub show_buffering_ui: bool, - #[serde(skip)] pub faster_model: Option, } impl SafetyBuffering { - pub(crate) fn with_treatment(mut self, treatment: &SafetyBufferingTreatment) -> Self { - self.show_buffering_ui = treatment.show_buffering_ui; - self.faster_model.clone_from(&treatment.faster_model); + pub(crate) fn with_treatment(mut self, treatment: Option<&SafetyBufferingTreatment>) -> Self { + if let Some(treatment) = treatment { + self.show_buffering_ui = treatment.show_buffering_ui; + self.faster_model.clone_from(&treatment.faster_model); + } else { + self.show_buffering_ui = true; + } self } } diff --git a/codex-rs/codex-api/src/endpoint/responses_websocket.rs b/codex-rs/codex-api/src/endpoint/responses_websocket.rs index 755228a9f59a..6c4e5bb6f31b 100644 --- a/codex-rs/codex-api/src/endpoint/responses_websocket.rs +++ b/codex-rs/codex-api/src/endpoint/responses_websocket.rs @@ -636,7 +636,7 @@ async fn run_websocket_response_stream( turn_state: Option<&OnceLock>, ) -> Result<(), ApiError> { let mut last_server_model: Option = None; - let mut safety_buffering_treatment = SafetyBufferingTreatment::default(); + let mut safety_buffering_treatment: Option = None; send_websocket_request( ws_stream, request_text, @@ -690,17 +690,10 @@ async fn run_websocket_response_stream( { let _ = turn_state.set(response_turn_state); } - if let Some(headers) = event.headers.as_ref().and_then(Value::as_object) - && let Some(treatment) = - treatment_from_headers(&json_headers_to_http_headers(headers)) - { - safety_buffering_treatment = treatment; - } let model_verifications = event.model_verifications(); let turn_moderation_metadata = event.turn_moderation_metadata(); - let safety_buffering = event - .safety_buffering() - .map(|buffering| buffering.with_treatment(&safety_buffering_treatment)); + let safety_buffering = + safety_buffering_for_event(&event, &mut safety_buffering_treatment); if event.kind() == "codex.rate_limits" { if let Some(snapshot) = parse_rate_limit_event(&text) { let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await; @@ -775,6 +768,21 @@ async fn run_websocket_response_stream( Ok(()) } +fn safety_buffering_for_event( + event: &ResponsesStreamEvent, + treatment: &mut Option, +) -> Option { + if let Some(headers) = event.headers.as_ref().and_then(Value::as_object) + && let Some(updated_treatment) = + treatment_from_headers(&json_headers_to_http_headers(headers)) + { + *treatment = Some(updated_treatment); + } + event + .safety_buffering() + .map(|buffering| buffering.with_treatment(treatment.as_ref())) +} + async fn send_websocket_request( ws_stream: &WsStream, request_text: String, @@ -1042,4 +1050,115 @@ mod tests { Some(&HeaderValue::from_static("default-only")) ); } + + #[test] + fn direct_websocket_safety_buffering_uses_wire_faster_model() { + let event: ResponsesStreamEvent = serde_json::from_value(json!({ + "type": "response.output_text.delta", + "safety_buffering": { + "use_cases": ["cyber"], + "reasons": ["user_risk"], + "faster_model": "gpt-fast-wire" + } + })) + .expect("deserialize safety buffering event"); + let mut treatment = None; + + let buffering = safety_buffering_for_event(&event, &mut treatment) + .expect("expected safety buffering payload"); + + assert!(buffering.show_buffering_ui); + assert_eq!(buffering.faster_model.as_deref(), Some("gpt-fast-wire")); + } + + #[test] + fn websocket_safety_buffering_treatment_metadata_overrides_wire_values() { + let metadata: ResponsesStreamEvent = serde_json::from_value(json!({ + "type": "codex.response.metadata", + "headers": { + "x-codex-safety-buffering-enabled": "true", + "x-codex-safety-buffering-faster-model": "gpt-fast-header" + } + })) + .expect("deserialize treatment metadata"); + let event: ResponsesStreamEvent = serde_json::from_value(json!({ + "type": "response.output_text.delta", + "safety_buffering": { + "use_cases": ["cyber"], + "reasons": ["user_risk"], + "faster_model": "gpt-fast-wire" + } + })) + .expect("deserialize safety buffering event"); + let mut treatment = None; + + assert!(safety_buffering_for_event(&metadata, &mut treatment).is_none()); + let buffering = safety_buffering_for_event(&event, &mut treatment) + .expect("expected safety buffering payload"); + + assert!(buffering.show_buffering_ui); + assert_eq!(buffering.faster_model.as_deref(), Some("gpt-fast-header")); + } + + #[test] + fn websocket_safety_buffering_treatment_can_clear_wire_values() { + for (headers, expected_show_buffering_ui) in [ + (json!({"x-codex-safety-buffering-enabled": "true"}), true), + ( + json!({ + "x-codex-safety-buffering-enabled": "false", + "x-codex-safety-buffering-faster-model": "ignored-header-model" + }), + false, + ), + ] { + let metadata: ResponsesStreamEvent = serde_json::from_value(json!({ + "type": "codex.response.metadata", + "headers": headers + })) + .expect("deserialize treatment metadata"); + let event: ResponsesStreamEvent = serde_json::from_value(json!({ + "type": "response.output_text.delta", + "safety_buffering": { + "use_cases": ["cyber"], + "reasons": ["user_risk"], + "faster_model": "gpt-fast-wire" + } + })) + .expect("deserialize safety buffering event"); + let mut treatment = None; + + assert!(safety_buffering_for_event(&metadata, &mut treatment).is_none()); + let buffering = safety_buffering_for_event(&event, &mut treatment) + .expect("expected safety buffering payload"); + + assert_eq!(buffering.show_buffering_ui, expected_show_buffering_ui); + assert_eq!(buffering.faster_model, None); + } + } + + #[test] + fn direct_websocket_safety_buffering_accepts_missing_and_null_faster_model() { + for faster_model in [None, Some(Value::Null)] { + let mut payload = json!({ + "type": "response.output_text.delta", + "safety_buffering": { + "use_cases": ["cyber"], + "reasons": ["user_risk"] + } + }); + if let Some(faster_model) = faster_model { + payload["safety_buffering"]["faster_model"] = faster_model; + } + let event: ResponsesStreamEvent = + serde_json::from_value(payload).expect("deserialize safety buffering event"); + let mut treatment = None; + + let buffering = safety_buffering_for_event(&event, &mut treatment) + .expect("expected safety buffering payload"); + + assert!(buffering.show_buffering_ui); + assert_eq!(buffering.faster_model, None); + } + } } diff --git a/codex-rs/codex-api/src/sse/responses.rs b/codex-rs/codex-api/src/sse/responses.rs index 71f73c0232cc..b0fc4cc5a19c 100644 --- a/codex-rs/codex-api/src/sse/responses.rs +++ b/codex-rs/codex-api/src/sse/responses.rs @@ -517,7 +517,7 @@ async fn process_sse_with_treatment( let turn_moderation_metadata = event.turn_moderation_metadata(); let safety_buffering = event .safety_buffering() - .map(|buffering| buffering.with_treatment(&safety_buffering_treatment)); + .map(|buffering| buffering.with_treatment(Some(&safety_buffering_treatment))); if let Some(model) = event.response_model() && last_server_model.as_deref() != Some(model.as_str()) @@ -1354,7 +1354,8 @@ mod tests { "delta": "hello", "safety_buffering": { "use_cases": ["cyber"], - "reasons": ["user_risk"] + "reasons": ["user_risk"], + "faster_model": "unexpected-wire-model" } }), json!({ @@ -1381,7 +1382,10 @@ mod tests { assert_matches!( &events[1], ResponseEvent::SafetyBuffering(buffering) - if buffering.use_cases == ["cyber"] && buffering.reasons == ["user_risk"] + if buffering.use_cases == ["cyber"] + && buffering.reasons == ["user_risk"] + && !buffering.show_buffering_ui + && buffering.faster_model.is_none() ); assert_matches!(&events[2], ResponseEvent::OutputTextDelta(delta) if delta == "hello"); assert_matches!(