Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codex-rs/app-server/tests/suite/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ mod request_permissions;
mod request_user_input;
mod request_validation;
mod review;
mod safety_buffering;
mod safety_check_downgrade;
#[cfg(not(target_os = "windows"))]
mod selected_capability_stack;
Expand Down
133 changes: 133 additions & 0 deletions codex-rs/app-server/tests/suite/v2/safety_buffering.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use anyhow::Result;
use app_test_support::TestAppServer;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::ModelSafetyBufferingUpdatedNotification;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStartResponse;
use codex_app_server_protocol::UserInput;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::path::Path;
use tempfile::TempDir;
use tokio::time::Duration;
use tokio::time::timeout;

const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(20);
const FASTER_MODEL: &str = "gpt-fast-wire";

#[tokio::test]
async fn direct_websocket_safety_buffering_reaches_app_server_notification() -> Result<()> {
skip_if_no_network!(Ok(()));

let mut buffering_event = responses::ev_response_created("resp-1");
buffering_event["safety_buffering"] = json!({
"use_cases": ["cyber"],
"reasons": ["user_risk"],
"faster_model": FASTER_MODEL,
});
let websocket_server = responses::start_websocket_server(vec![vec![
vec![
json!({
"type": "codex.response.metadata",
"headers": {"x-codex-safety-buffering-enabled": "false"},
}),
responses::ev_response_created("warm-1"),
responses::ev_completed("warm-1"),
],
vec![
buffering_event,
responses::ev_assistant_message("msg-1", "Done"),
responses::ev_completed("resp-1"),
],
]])
.await;

let codex_home = TempDir::new()?;
create_websocket_config(
codex_home.path(),
&websocket_server.uri().replacen("ws://", "http://", 1),
)?;
let mut mcp = TestAppServer::new_with_auto_env(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

let thread_request = mcp
.send_thread_start_request_with_auto_env(ThreadStartParams::default())
.await?;
let thread_response: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(thread_request)),
)
.await??;
let ThreadStartResponse { thread, .. } = to_response(thread_response)?;

let turn_request = mcp
.send_turn_start_request(TurnStartParams {
thread_id: thread.id.clone(),
client_user_message_id: None,
input: vec![UserInput::Text {
text: "Check this request".to_string(),
text_elements: Vec::new(),
}],
..Default::default()
})
.await?;
let turn_response: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(turn_request)),
)
.await??;
let TurnStartResponse { turn } = to_response(turn_response)?;

let notification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("model/safetyBuffering/updated"),
)
.await??;
let notification: ModelSafetyBufferingUpdatedNotification =
serde_json::from_value(notification.params.expect("notification params"))?;

assert_eq!(notification.thread_id, thread.id);
assert_eq!(notification.turn_id, turn.id);
assert_eq!(notification.model, "mock-model");
assert_eq!(notification.use_cases, ["cyber"]);
assert_eq!(notification.reasons, ["user_risk"]);
assert!(notification.show_buffering_ui);
assert_eq!(notification.faster_model.as_deref(), Some(FASTER_MODEL));

timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("turn/completed"),
)
.await??;
websocket_server.shutdown().await;
Ok(())
}

fn create_websocket_config(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
std::fs::write(
codex_home.join("config.toml"),
format!(
r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "read-only"

model_provider = "mock_provider"

[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
supports_websockets = true
"#
),
)
}
11 changes: 7 additions & 4 deletions codex-rs/codex-api/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,17 @@ pub struct SafetyBuffering {
pub reasons: Vec<String>,
#[serde(skip)]
pub show_buffering_ui: bool,
#[serde(skip)]
pub faster_model: Option<String>,
}

impl SafetyBuffering {
pub(crate) fn with_treatment(mut self, treatment: &SafetyBufferingTreatment) -> Self {
self.show_buffering_ui = treatment.show_buffering_ui;
self.faster_model.clone_from(&treatment.faster_model);
pub(crate) fn with_treatment(mut self, treatment: Option<&SafetyBufferingTreatment>) -> Self {
if let Some(treatment) = treatment {
self.show_buffering_ui = treatment.show_buffering_ui;
self.faster_model.clone_from(&treatment.faster_model);
} else {
self.show_buffering_ui = true;
}
self
}
}
Expand Down
139 changes: 129 additions & 10 deletions codex-rs/codex-api/src/endpoint/responses_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ async fn run_websocket_response_stream(
turn_state: Option<&OnceLock<String>>,
) -> Result<(), ApiError> {
let mut last_server_model: Option<String> = None;
let mut safety_buffering_treatment = SafetyBufferingTreatment::default();
let mut safety_buffering_treatment: Option<SafetyBufferingTreatment> = None;
send_websocket_request(
ws_stream,
request_text,
Expand Down Expand Up @@ -690,17 +690,10 @@ async fn run_websocket_response_stream(
{
let _ = turn_state.set(response_turn_state);
}
if let Some(headers) = event.headers.as_ref().and_then(Value::as_object)
&& let Some(treatment) =
treatment_from_headers(&json_headers_to_http_headers(headers))
{
safety_buffering_treatment = treatment;
}
let model_verifications = event.model_verifications();
let turn_moderation_metadata = event.turn_moderation_metadata();
let safety_buffering = event
.safety_buffering()
.map(|buffering| buffering.with_treatment(&safety_buffering_treatment));
let safety_buffering =
safety_buffering_for_event(&event, &mut safety_buffering_treatment);
if event.kind() == "codex.rate_limits" {
if let Some(snapshot) = parse_rate_limit_event(&text) {
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
Expand Down Expand Up @@ -775,6 +768,21 @@ async fn run_websocket_response_stream(
Ok(())
}

fn safety_buffering_for_event(
event: &ResponsesStreamEvent,
treatment: &mut Option<SafetyBufferingTreatment>,
) -> Option<crate::common::SafetyBuffering> {
if let Some(headers) = event.headers.as_ref().and_then(Value::as_object)
&& let Some(updated_treatment) =
treatment_from_headers(&json_headers_to_http_headers(headers))
{
*treatment = Some(updated_treatment);
}
event
.safety_buffering()
.map(|buffering| buffering.with_treatment(treatment.as_ref()))
}

async fn send_websocket_request(
ws_stream: &WsStream,
request_text: String,
Expand Down Expand Up @@ -1042,4 +1050,115 @@ mod tests {
Some(&HeaderValue::from_static("default-only"))
);
}

#[test]
fn direct_websocket_safety_buffering_uses_wire_faster_model() {
let event: ResponsesStreamEvent = serde_json::from_value(json!({
"type": "response.output_text.delta",
"safety_buffering": {
"use_cases": ["cyber"],
"reasons": ["user_risk"],
"faster_model": "gpt-fast-wire"
}
}))
.expect("deserialize safety buffering event");
let mut treatment = None;

let buffering = safety_buffering_for_event(&event, &mut treatment)
.expect("expected safety buffering payload");

assert!(buffering.show_buffering_ui);
assert_eq!(buffering.faster_model.as_deref(), Some("gpt-fast-wire"));
}

#[test]
fn websocket_safety_buffering_treatment_metadata_overrides_wire_values() {
let metadata: ResponsesStreamEvent = serde_json::from_value(json!({
"type": "codex.response.metadata",
"headers": {
"x-codex-safety-buffering-enabled": "true",
"x-codex-safety-buffering-faster-model": "gpt-fast-header"
}
}))
.expect("deserialize treatment metadata");
let event: ResponsesStreamEvent = serde_json::from_value(json!({
"type": "response.output_text.delta",
"safety_buffering": {
"use_cases": ["cyber"],
"reasons": ["user_risk"],
"faster_model": "gpt-fast-wire"
}
}))
.expect("deserialize safety buffering event");
let mut treatment = None;

assert!(safety_buffering_for_event(&metadata, &mut treatment).is_none());
let buffering = safety_buffering_for_event(&event, &mut treatment)
.expect("expected safety buffering payload");

assert!(buffering.show_buffering_ui);
assert_eq!(buffering.faster_model.as_deref(), Some("gpt-fast-header"));
}

#[test]
fn websocket_safety_buffering_treatment_can_clear_wire_values() {
for (headers, expected_show_buffering_ui) in [
(json!({"x-codex-safety-buffering-enabled": "true"}), true),
(
json!({
"x-codex-safety-buffering-enabled": "false",
"x-codex-safety-buffering-faster-model": "ignored-header-model"
}),
false,
),
] {
let metadata: ResponsesStreamEvent = serde_json::from_value(json!({
"type": "codex.response.metadata",
"headers": headers
}))
.expect("deserialize treatment metadata");
let event: ResponsesStreamEvent = serde_json::from_value(json!({
"type": "response.output_text.delta",
"safety_buffering": {
"use_cases": ["cyber"],
"reasons": ["user_risk"],
"faster_model": "gpt-fast-wire"
}
}))
.expect("deserialize safety buffering event");
let mut treatment = None;

assert!(safety_buffering_for_event(&metadata, &mut treatment).is_none());
let buffering = safety_buffering_for_event(&event, &mut treatment)
.expect("expected safety buffering payload");

assert_eq!(buffering.show_buffering_ui, expected_show_buffering_ui);
assert_eq!(buffering.faster_model, None);
}
}

#[test]
fn direct_websocket_safety_buffering_accepts_missing_and_null_faster_model() {
for faster_model in [None, Some(Value::Null)] {
let mut payload = json!({
"type": "response.output_text.delta",
"safety_buffering": {
"use_cases": ["cyber"],
"reasons": ["user_risk"]
}
});
if let Some(faster_model) = faster_model {
payload["safety_buffering"]["faster_model"] = faster_model;
}
let event: ResponsesStreamEvent =
serde_json::from_value(payload).expect("deserialize safety buffering event");
let mut treatment = None;

let buffering = safety_buffering_for_event(&event, &mut treatment)
.expect("expected safety buffering payload");

assert!(buffering.show_buffering_ui);
assert_eq!(buffering.faster_model, None);
}
}
}
10 changes: 7 additions & 3 deletions codex-rs/codex-api/src/sse/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ async fn process_sse_with_treatment(
let turn_moderation_metadata = event.turn_moderation_metadata();
let safety_buffering = event
.safety_buffering()
.map(|buffering| buffering.with_treatment(&safety_buffering_treatment));
.map(|buffering| buffering.with_treatment(Some(&safety_buffering_treatment)));

if let Some(model) = event.response_model()
&& last_server_model.as_deref() != Some(model.as_str())
Expand Down Expand Up @@ -1354,7 +1354,8 @@ mod tests {
"delta": "hello",
"safety_buffering": {
"use_cases": ["cyber"],
"reasons": ["user_risk"]
"reasons": ["user_risk"],
"faster_model": "unexpected-wire-model"
}
}),
json!({
Expand All @@ -1381,7 +1382,10 @@ mod tests {
assert_matches!(
&events[1],
ResponseEvent::SafetyBuffering(buffering)
if buffering.use_cases == ["cyber"] && buffering.reasons == ["user_risk"]
if buffering.use_cases == ["cyber"]
&& buffering.reasons == ["user_risk"]
&& !buffering.show_buffering_ui
&& buffering.faster_model.is_none()
);
assert_matches!(&events[2], ResponseEvent::OutputTextDelta(delta) if delta == "hello");
assert_matches!(
Expand Down
Loading