From a3b186169a92d5ac452e195e421c8143ba6d2d8d Mon Sep 17 00:00:00 2001 From: Matthew Gumport Date: Mon, 16 Mar 2026 16:50:04 -0700 Subject: [PATCH 01/52] expose content_type on multirangeinfo --- .bleep | 2 +- pingora-proxy/src/proxy_cache.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bleep b/.bleep index 64f07f84..0b0eaf67 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -f0b43320bb1a5f7788a7d0e90a804e045f0af2fb +b94c8d2ff134d87e6980f671954170437a673ddb \ No newline at end of file diff --git a/pingora-proxy/src/proxy_cache.rs b/pingora-proxy/src/proxy_cache.rs index 43b2ace9..748de963 100644 --- a/pingora-proxy/src/proxy_cache.rs +++ b/pingora-proxy/src/proxy_cache.rs @@ -1276,7 +1276,7 @@ pub mod range_filter { pub ranges: Vec>, pub boundary: String, total_length: usize, - content_type: Option, + pub content_type: Option, } impl MultiRangeInfo { From 21fa59297c36bb5f7a2eb1249ba66dd082e9be1b Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Fri, 6 Mar 2026 13:37:11 -0800 Subject: [PATCH 02/52] Allow adjusting upstream modules on response header recv Adds an async filter (feature-gated) to adjust upstream modules prior to those modules (currently just compression) running. --- .bleep | 2 +- docs/user_guide/phase.md | 8 +++++++- docs/user_guide/phase_chart.md | 3 ++- pingora-proxy/Cargo.toml | 1 + pingora-proxy/src/proxy_custom.rs | 6 ++++++ pingora-proxy/src/proxy_h1.rs | 6 ++++++ pingora-proxy/src/proxy_h2.rs | 6 ++++++ pingora-proxy/src/proxy_trait.rs | 33 +++++++++++++++++++++++++++++++ pingora/Cargo.toml | 6 ++++++ 9 files changed, 68 insertions(+), 3 deletions(-) diff --git a/.bleep b/.bleep index 0b0eaf67..d4078332 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -b94c8d2ff134d87e6980f671954170437a673ddb \ No newline at end of file +e192cb4b06a03938fa92528156c6fbc681f42766 \ No newline at end of file diff --git a/docs/user_guide/phase.md b/docs/user_guide/phase.md index 3c80f913..5f3a891a 100644 --- a/docs/user_guide/phase.md +++ b/docs/user_guide/phase.md @@ -29,7 +29,8 @@ Pingora-proxy allows users to insert arbitrary logic into the life of a request. upstream_request_filter --> request_body_filter; request_body_filter --> SendReq{{IO: send request to upstream}}; SendReq-->RecvResp{{IO: read response from upstream}}; - RecvResp-->upstream_response_filter-->response_filter-->upstream_response_body_filter-->response_body_filter-->logging-->endreq("request done"); + RecvResp-.feature: adjust_upstream_modules.->adjust_upstream_modules; + adjust_upstream_modules-->upstream_response_filter-->response_filter-->upstream_response_body_filter-->response_body_filter-->logging-->endreq("request done"); fail_to_connect --can retry-->upstream_peer; fail_to_connect --can't retry-->fail_to_proxy--send error response-->logging; @@ -92,6 +93,11 @@ If the error is not retry-able, the request will end. ### `upstream_request_filter()` This phase is to modify requests before sending to upstream. +### `adjust_upstream_modules()` _(feature: `adjust_upstream_modules`)_ +This phase is triggered when the upstream response header arrives, before upstream modules (such as `upstream_compression`) process it. + +Use this to configure upstream module behavior based on the response header, e.g. setting a dictionary for dictionary-based content encoding. The response header is provided as an immutable reference; to modify the response header itself, use `upstream_response_filter()` instead. + ### `upstream_response_filter()/upstream_response_body_filter()/upstream_response_trailer_filter()` This phase is triggered after an upstream response header/body/trailer is received. diff --git a/docs/user_guide/phase_chart.md b/docs/user_guide/phase_chart.md index 94988724..b915f950 100644 --- a/docs/user_guide/phase_chart.md +++ b/docs/user_guide/phase_chart.md @@ -14,7 +14,8 @@ Pingora proxy phases without caching upstream_request_filter --> request_body_filter; request_body_filter --> SendReq{{IO: send request to upstream}}; SendReq-->RecvResp{{IO: read response from upstream}}; - RecvResp-->upstream_response_filter-->response_filter-->upstream_response_body_filter-->response_body_filter-->logging-->endreq("request done"); + RecvResp-.feature: adjust_upstream_modules.->adjust_upstream_modules; + adjust_upstream_modules-->upstream_response_filter-->response_filter-->upstream_response_body_filter-->response_body_filter-->logging-->endreq("request done"); fail_to_connect --can retry-->upstream_peer; fail_to_connect --can't retry-->fail_to_proxy--send error response-->logging; diff --git a/pingora-proxy/Cargo.toml b/pingora-proxy/Cargo.toml index 1f367d89..27d98ddc 100644 --- a/pingora-proxy/Cargo.toml +++ b/pingora-proxy/Cargo.toml @@ -69,6 +69,7 @@ s2n = ["pingora-core/s2n", "pingora-cache/s2n", "any_tls"] openssl_derived = ["any_tls"] any_tls = [] sentry = ["pingora-core/sentry"] +adjust_upstream_modules = [] connection_filter = ["pingora-core/connection_filter"] prometheus = ["pingora-core/prometheus"] diff --git a/pingora-proxy/src/proxy_custom.rs b/pingora-proxy/src/proxy_custom.rs index 63079111..b571b3ce 100644 --- a/pingora-proxy/src/proxy_custom.rs +++ b/pingora-proxy/src/proxy_custom.rs @@ -386,6 +386,12 @@ where // skip downstream filtering entirely as the 304 will not be sent break; } + #[cfg(feature = "adjust_upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } session.upstream_compression.response_filter(&mut t); // check error and abort // otherwise the error is surfaced via write_response_tasks() diff --git a/pingora-proxy/src/proxy_h1.rs b/pingora-proxy/src/proxy_h1.rs index 9f04289c..9f498aa0 100644 --- a/pingora-proxy/src/proxy_h1.rs +++ b/pingora-proxy/src/proxy_h1.rs @@ -460,6 +460,12 @@ where // skip downstream filtering entirely as the 304 will not be sent break; } + #[cfg(feature = "adjust_upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } session.upstream_compression.response_filter(&mut t); let task = self.h1_response_filter(session, t, ctx, &mut serve_from_cache, diff --git a/pingora-proxy/src/proxy_h2.rs b/pingora-proxy/src/proxy_h2.rs index 0d633e4a..acf61f07 100644 --- a/pingora-proxy/src/proxy_h2.rs +++ b/pingora-proxy/src/proxy_h2.rs @@ -416,6 +416,12 @@ where // skip downstream filtering entirely as the 304 will not be sent break; } + #[cfg(feature = "adjust_upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } session.upstream_compression.response_filter(&mut t); // check error and abort // otherwise the error is surfaced via write_response_tasks() diff --git a/pingora-proxy/src/proxy_trait.rs b/pingora-proxy/src/proxy_trait.rs index f4193fca..b81fbb9b 100644 --- a/pingora-proxy/src/proxy_trait.rs +++ b/pingora-proxy/src/proxy_trait.rs @@ -293,6 +293,39 @@ pub trait ProxyHttp { Ok(()) } + /// Adjust upstream modules before they process the response header. + /// + /// This filter is called when the upstream response header arrives, before upstream modules + /// (such as `upstream_compression`) run their response header filter. Use this to configure + /// module behavior based on the response, e.g. setting a dictionary for dictionary-based + /// content encoding. + /// + /// This filter may be called more than once per request if the upstream sends informational + /// (1xx) response headers before the final response. Implementations can check + /// [`upstream_response.status.is_informational()`](http::StatusCode::is_informational) to + /// distinguish informational headers from the final response if needed. + /// + /// `end_of_stream` indicates whether the response header is also the end of the response + /// (e.g. for HEAD responses or 304s with no body). + /// + /// The response header is provided as an immutable reference. To modify the response header + /// itself, use [`Self::upstream_response_filter()`] instead. + /// + /// This filter requires the `adjust_upstream_modules` feature to be enabled. + #[cfg(feature = "adjust_upstream_modules")] + async fn adjust_upstream_modules( + &self, + _session: &mut Session, + _upstream_response: &ResponseHeader, + _end_of_stream: bool, + _ctx: &mut Self::CTX, + ) -> Result<()> + where + Self::CTX: Send + Sync, + { + Ok(()) + } + /// Modify the response header from the upstream /// /// The modification is before caching, so any change here will be stored in the cache if enabled. diff --git a/pingora/Cargo.toml b/pingora/Cargo.toml index dd890bdb..8d2c25f7 100644 --- a/pingora/Cargo.toml +++ b/pingora/Cargo.toml @@ -126,6 +126,12 @@ time = [] ## Enable sentry for error notifications sentry = ["pingora-core/sentry"] +## Enable the `adjust_upstream_modules` filter phase on [ProxyHttp](crate::proxy::ProxyHttp) +## +## Allows configuring upstream modules (e.g. upstream compression) based on the +## response header before they process it. +adjust_upstream_modules = ["pingora-proxy?/adjust_upstream_modules"] + ## Enable pre-TLS connection filtering connection_filter = [ "pingora-core/connection_filter", From af7dd468f471dc57505975eb10e982fae0327708 Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Wed, 18 Mar 2026 10:26:57 -0700 Subject: [PATCH 03/52] Don't init body reader on HEAD 1xx This prevents headers like 100-continue from ending the stream and causing hangs while the downstream is waiting. --- .bleep | 2 +- pingora-core/src/protocols/http/v1/client.rs | 291 ++++++++++++++++++- pingora-proxy/tests/test_upstream.rs | 68 +++++ 3 files changed, 353 insertions(+), 8 deletions(-) diff --git a/.bleep b/.bleep index d4078332..0e1e809f 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -e192cb4b06a03938fa92528156c6fbc681f42766 \ No newline at end of file +77a8319b16274d93fe07d158d766a822d7de7ee4 \ No newline at end of file diff --git a/pingora-core/src/protocols/http/v1/client.rs b/pingora-core/src/protocols/http/v1/client.rs index 5f9e4610..a60aad1f 100644 --- a/pingora-core/src/protocols/http/v1/client.rs +++ b/pingora-core/src/protocols/http/v1/client.rs @@ -625,13 +625,6 @@ impl HttpSession { // follow https://datatracker.ietf.org/doc/html/rfc9112#section-6.3 let preread_body = self.preread_body.as_ref().unwrap().get(&self.buf[..]); - if let Some(req) = self.request_written.as_ref() { - if req.method == http::method::Method::HEAD { - self.body_reader.init_content_length(0, preread_body); - return; - } - } - let upgraded = if let Some(code) = self.get_status() { match code.as_u16() { 101 => self.is_upgrade_req(), @@ -650,6 +643,13 @@ impl HttpSession { false }; + if let Some(req) = self.request_written.as_ref() { + if req.method == http::method::Method::HEAD { + self.body_reader.init_content_length(0, preread_body); + return; + } + } + if upgraded { self.body_reader.init_close_delimited(preread_body); self.close_delimited_resp = true; @@ -2224,6 +2224,283 @@ hello"; http_stream.respect_keepalive(); assert!(!http_stream.will_keepalive()); } + + #[tokio::test] + async fn read_informational_head_request() { + init_log(); + // HEAD request that receives 100 Continue followed by 200 OK + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input1 = b"HTTP/1.1 100 Continue\r\n\r\n"; + let input2 = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n"; + + let mock_io = Builder::new() + .write(&wire[..]) + .read(&input1[..]) + .read(&input2[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + // Write HEAD request + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + // Read 100 Continue + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 100); + assert!(!eob, "100 Continue for HEAD should not signal end of body"); + } + _ => { + panic!("task should be informational header") + } + } + + // Read final 200 OK + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 200); + assert!(eob, "HEAD 200 response should signal end of body"); + } + _ => { + panic!("task should be final header") + } + } + + // Body reader should be Complete(0) for HEAD + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + } + + #[tokio::test] + async fn read_informational_multiple_head_request() { + init_log(); + // HEAD request that receives 100 Continue, 103 Early Hints, then 200 OK + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 103 Early Hints\r\n\r\nHTTP/1.1 200 OK\r\nContent-Length: 50\r\n\r\n"; + + let mock_io = Builder::new().write(&wire[..]).read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + // Read 100 Continue + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 100); + assert!(!eob, "100 Continue for HEAD should not signal end of body"); + } + _ => { + panic!("task should be 100 header") + } + } + + // Read 103 Early Hints + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 103); + assert!( + !eob, + "103 Early Hints for HEAD should not signal end of body" + ); + } + _ => { + panic!("task should be 103 header") + } + } + + // Read 200 OK — end of body + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 200); + assert!(eob, "HEAD 200 response should signal end of body"); + } + _ => { + panic!("task should be final header") + } + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + } + + #[tokio::test] + async fn read_basic_head() { + init_log(); + // Basic HEAD + 200 + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n"; + + let mock_io = Builder::new().write(&wire[..]).read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 200); + assert!(eob, "HEAD 200 should be end of body"); + } + _ => { + panic!("task should be header") + } + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + + // Keepalive should work for a properly-framed HEAD response + http_stream.respect_keepalive(); + assert!(http_stream.will_keepalive()); + } + + #[tokio::test] + async fn read_head_informational_keepalive() { + init_log(); + // HEAD + 100 Continue + 200 OK, then verify keepalive is preserved. + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input1 = b"HTTP/1.1 100 Continue\r\n\r\n"; + let input2 = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n"; + + let mock_io = Builder::new() + .write(&wire[..]) + .read(&input1[..]) + .read(&input2[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + // 100 Continue + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 100); + assert!(!eob); + } + _ => panic!("task should be informational header"), + } + + // 200 OK + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 200); + assert!(eob); + } + _ => panic!("task should be final header"), + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + + // Keepalive must still work after the 100 + 200 sequence + http_stream.respect_keepalive(); + assert!(http_stream.will_keepalive()); + } + + #[tokio::test] + async fn read_head_204() { + init_log(); + // HEAD + 204 No Content + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input = b"HTTP/1.1 204 No Content\r\n\r\n"; + + let mock_io = Builder::new().write(&wire[..]).read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 204); + assert!(eob, "HEAD 204 should be end of body"); + } + _ => panic!("task should be header"), + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + } + + #[tokio::test] + async fn read_head_304() { + init_log(); + // HEAD + 304 Not Modified + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input = b"HTTP/1.1 304 Not Modified\r\nContent-Length: 100\r\n\r\n"; + + let mock_io = Builder::new().write(&wire[..]).read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 304); + assert!(eob, "HEAD 304 should be end of body"); + } + _ => panic!("task should be header"), + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + } + + #[tokio::test] + async fn read_head_101_non_upgrade() { + init_log(); + // HEAD + 101 where the request is not an upgrade request. + // Contrived, but verifies the new code path: 101 check fires first, + // is_upgrade_req() returns false, then HEAD check fires. + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input = b"HTTP/1.1 101 Switching Protocols\r\n\r\n"; + + let mock_io = Builder::new().write(&wire[..]).read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 101); + // HEAD without Upgrade headers → not an upgrade, body is "done" + assert!(eob, "HEAD 101 (non-upgrade) should be end of body"); + } + _ => panic!("task should be header"), + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + } } #[cfg(test)] diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs index b22a1ead..e1ac37f8 100644 --- a/pingora-proxy/tests/test_upstream.rs +++ b/pingora-proxy/tests/test_upstream.rs @@ -532,6 +532,74 @@ async fn test_h2_upstream_no_end_stream_read_timeout() { } } +/// Mock origin that sends 100 Continue then a final response for any request. +/// Returns the port the server is listening on. +async fn mock_100_continue_server() -> u16 { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + tokio::spawn(async move { + if let Ok((mut stream, _addr)) = listener.accept().await { + // Read the request (just drain it) + let mut buf = [0u8; 4096]; + let _ = stream.read(&mut buf).await.unwrap(); + + // Send 100 Continue + stream + .write_all(b"HTTP/1.1 100 Continue\r\n\r\n") + .await + .unwrap(); + // Small delay so the client reads the 100 separately + tokio::time::sleep(Duration::from_millis(100)).await; + + // Send final 200 OK with Content-Length (but no body for HEAD) + stream + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 42\r\n\r\n") + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + } + }); + + port +} + +#[tokio::test] +async fn test_head_with_100_continue() { + init(); + + let port = mock_100_continue_server().await; + + let mut stream = TcpStream::connect("127.0.0.1:6147").await.unwrap(); + stream + .write_all( + format!("HEAD / HTTP/1.1\r\nHost: localhost\r\nx-port: {port}\r\n\r\n").as_bytes(), + ) + .await + .unwrap(); + + // Read through any 1xx until we get the final (non-1xx) response + let result = timeout(Duration::from_secs(5), async { + let mut resp; + let mut body; + loop { + (resp, body) = read_response_header(&mut stream).await; + if resp.status.as_u16() >= 200 { + return (resp, body); + } + } + }) + .await + .expect("should not time out waiting for final response"); + + let (resp, body) = result; + assert_eq!(resp.status.as_u16(), 200); + // HEAD responses have no body even with Content-Length + assert!(body.is_empty(), "HEAD response should have no body"); +} + mod test_cache { use super::*; use std::str::FromStr; From d0ede94894d4a8180ae75d8890caab2fbe31b08b Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Wed, 4 Mar 2026 17:36:41 -0800 Subject: [PATCH 04/52] Make tracing an optional feature in pingora-cache --- .bleep | 2 +- pingora-cache/Cargo.toml | 7 +- pingora-cache/src/lib.rs | 8 ++- pingora-cache/src/lock.rs | 2 +- pingora-cache/src/memory.rs | 2 +- pingora-cache/src/put.rs | 3 +- pingora-cache/src/trace.rs | 134 +++++++++++++++++++++++++++++++++--- pingora-proxy/Cargo.toml | 1 + pingora/Cargo.toml | 1 + 9 files changed, 141 insertions(+), 19 deletions(-) diff --git a/.bleep b/.bleep index 0e1e809f..67f51e96 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -77a8319b16274d93fe07d158d766a822d7de7ee4 \ No newline at end of file +3aee79890a3f6c64716f16fe58324800909f8ff9 \ No newline at end of file diff --git a/pingora-cache/Cargo.toml b/pingora-cache/Cargo.toml index 401d827c..44a063ef 100644 --- a/pingora-cache/Cargo.toml +++ b/pingora-cache/Cargo.toml @@ -37,8 +37,8 @@ httpdate = "1.0.2" log = { workspace = true } async-trait = { workspace = true } parking_lot = "0.12" -cf-rustracing = "1.0" -cf-rustracing-jaeger = "1.0" +cf-rustracing = { version = "1.0", optional = true } +cf-rustracing-jaeger = { version = "1.0", optional = true } rmp = "0.8.14" tokio = { workspace = true } lru = { workspace = true } @@ -49,6 +49,8 @@ strum = { version = "0.26", features = ["derive"] } rand = "0.8" [dev-dependencies] +cf-rustracing = "1.0" +cf-rustracing-jaeger = "1.0" tokio-test = "0.4" tokio = { workspace = true, features = ["fs"] } env_logger = "0.11" @@ -73,3 +75,4 @@ openssl = ["pingora-core/openssl"] boringssl = ["pingora-core/boringssl"] rustls = ["pingora-core/rustls"] s2n = ["pingora-core/s2n"] +trace = ["dep:cf-rustracing", "dep:cf-rustracing-jaeger"] diff --git a/pingora-cache/src/lib.rs b/pingora-cache/src/lib.rs index 867cff08..6d13409b 100644 --- a/pingora-cache/src/lib.rs +++ b/pingora-cache/src/lib.rs @@ -16,7 +16,6 @@ #![allow(clippy::new_without_default)] -use cf_rustracing::tag::Tag; use http::{method::Method, request::Parts as ReqHeader, response::Parts as RespHeader}; use key::{CacheHashKey, CompactCacheKey, HashBinary}; use lock::WritePermit; @@ -27,7 +26,7 @@ use pingora_timeout::timeout; use std::time::{Duration, Instant, SystemTime}; use storage::MissFinishType; use strum::IntoStaticStr; -use trace::{CacheTraceCTX, Span}; +use trace::{CacheTraceCTX, Span, Tag}; pub mod cache_control; pub mod eviction; @@ -435,6 +434,7 @@ impl HttpCache { self.phase = CachePhase::Disabled(reason); self.release_write_lock(reason); // enabled_ctx will be cleared out + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] let mut inner_enabled = self .inner_mut() .enabled_ctx @@ -1049,6 +1049,7 @@ impl HttpCache { inner_enabled.meta.replace(meta); + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] let mut span = inner_enabled.traces.child("update_meta"); let result = inner_enabled .storage @@ -1291,6 +1292,7 @@ impl HttpCache { .enabled_ctx .as_mut() .expect("Cache enabled on cache_lookup"); + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] let mut span = inner_enabled.traces.child("lookup"); let key = inner.key.as_ref().unwrap(); // safe, this phase should have cache key let now = Instant::now(); @@ -1446,6 +1448,7 @@ impl HttpCache { /// Check [Self::is_cache_locked()], panic if this request doesn't have a read lock. pub async fn cache_lock_wait(&mut self) -> LockStatus { let inner_enabled = self.inner_enabled_mut(); + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] let mut span = inner_enabled.traces.child("cache_lock"); // should always call is_cache_locked() before this function, which should guarantee that // the inner cache has a read lock and lock ctx @@ -1535,6 +1538,7 @@ impl HttpCache { }) } + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] async fn purge_impl( storage: &'static (dyn storage::Storage + Sync), eviction: Option<&'static (dyn eviction::EvictionManager + Sync)>, diff --git a/pingora-cache/src/lock.rs b/pingora-cache/src/lock.rs index 5633b09c..102b3380 100644 --- a/pingora-cache/src/lock.rs +++ b/pingora-cache/src/lock.rs @@ -14,8 +14,8 @@ //! Cache lock +use crate::trace::{Span, Tag}; use crate::{hashtable::ConcurrentHashTable, key::CacheHashKey, CacheKey}; -use crate::{Span, Tag}; use http::Extensions; use pingora_timeout::timeout; diff --git a/pingora-cache/src/memory.rs b/pingora-cache/src/memory.rs index 6ab57c80..e6e12acf 100644 --- a/pingora-cache/src/memory.rs +++ b/pingora-cache/src/memory.rs @@ -426,7 +426,7 @@ impl Storage for MemCache { #[cfg(test)] mod test { use super::*; - use cf_rustracing::span::Span; + use crate::trace::Span; use once_cell::sync::Lazy; fn gen_meta() -> CacheMeta { diff --git a/pingora-cache/src/put.rs b/pingora-cache/src/put.rs index fbbbb70e..94265055 100644 --- a/pingora-cache/src/put.rs +++ b/pingora-cache/src/put.rs @@ -84,6 +84,7 @@ impl CachePutCtx { } async fn put_header(&mut self, meta: CacheMeta) -> Result<()> { + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] let mut trace = self.trace.child("cache put header", |o| o.start()); let miss_handler = self .storage @@ -239,7 +240,7 @@ impl CachePutCtx { #[cfg(test)] mod test { use super::*; - use cf_rustracing::span::Span; + use crate::trace::Span; use once_cell::sync::Lazy; struct TestCachePut(); diff --git a/pingora-cache/src/trace.rs b/pingora-cache/src/trace.rs index f27929a2..e8ab85cf 100644 --- a/pingora-cache/src/trace.rs +++ b/pingora-cache/src/trace.rs @@ -13,26 +13,129 @@ // limitations under the License. //! Distributed tracing helpers +//! +//! When the `trace` feature is enabled, this module re-exports the real +//! [`cf_rustracing`]/[`cf_rustracing_jaeger`] span types. +//! +//! When the `trace` feature is **disabled**, lightweight no-op shim types are +//! provided instead so that the rest of the crate compiles without pulling in +//! the tracing dependencies. -use cf_rustracing_jaeger::span::SpanContextState; use std::time::SystemTime; use crate::{CacheMeta, CachePhase, HitStatus}; -pub use cf_rustracing::tag::Tag; +// --------------------------------------------------------------------------- +// Real tracing implementation (feature = "trace") +// --------------------------------------------------------------------------- +#[cfg(feature = "trace")] +mod real { + pub use cf_rustracing::tag::Tag; -pub type Span = cf_rustracing::span::Span; -pub type SpanHandle = cf_rustracing::span::SpanHandle; + use cf_rustracing_jaeger::span::SpanContextState; -#[derive(Debug)] -pub(crate) struct CacheTraceCTX { - // parent span - pub cache_span: Span, - // only spans across multiple calls need to store here - pub miss_span: Span, - pub hit_span: Span, + pub type Span = cf_rustracing::span::Span; + pub type SpanHandle = cf_rustracing::span::SpanHandle; } +#[cfg(feature = "trace")] +pub use real::*; + +// --------------------------------------------------------------------------- +// No-op shim types (feature = "trace" disabled) +// --------------------------------------------------------------------------- +#[cfg(not(feature = "trace"))] +mod noop { + /// A no-op replacement for [`cf_rustracing::tag::Tag`]. + #[derive(Debug)] + pub struct Tag { + _priv: (), + } + + impl Tag { + /// Create a no-op tag. All arguments are ignored. + #[inline] + pub fn new(_name: N, _value: V) -> Self { + Tag { _priv: () } + } + } + + /// A no-op replacement for a rustracing `Span`. + #[derive(Debug)] + pub struct Span { + _priv: (), + } + + impl Span { + /// Return an inactive (no-op) span. + #[inline] + pub fn inactive() -> Self { + Span { _priv: () } + } + + /// Return a no-op handle. + #[inline] + pub fn handle(&self) -> SpanHandle { + SpanHandle { _priv: () } + } + + /// No-op: create a child span. + #[inline] + pub fn child(&self, _name: &'static str, _f: F) -> Span + where + F: FnOnce(SpanOptionsPlaceholder) -> SpanOptionsPlaceholder, + { + Span::inactive() + } + + /// No-op: set a single tag via a closure. + #[inline] + pub fn set_tag Tag>(&self, _f: F) {} + + /// No-op: set multiple tags via a closure. + #[inline] + pub fn set_tags(&self, _f: F) + where + F: FnOnce() -> I, + I: IntoIterator, + { + } + + /// No-op: set a finish time. + #[inline] + pub fn set_finish_time std::time::SystemTime>(&self, _f: F) {} + } + + /// Placeholder type used in [`Span::child`] closure signatures so that + /// existing call-sites like `span.child("name", |o| o.start())` compile. + #[doc(hidden)] + pub struct SpanOptionsPlaceholder { + _priv: (), + } + + impl SpanOptionsPlaceholder { + /// No-op: mirrors `SpanOptions::start()`. + #[inline] + pub fn start(self) -> Self { + self + } + } + + /// A no-op replacement for a rustracing `SpanHandle`. + #[derive(Debug)] + pub struct SpanHandle { + _priv: (), + } +} + +#[cfg(not(feature = "trace"))] +pub use noop::*; + +// --------------------------------------------------------------------------- +// Shared helpers (work with both real and no-op types) +// --------------------------------------------------------------------------- + +/// Tag a span with metadata from a [`CacheMeta`]. pub fn tag_span_with_meta(span: &mut Span, meta: &CacheMeta) { fn ts2epoch(ts: SystemTime) -> f64 { ts.duration_since(SystemTime::UNIX_EPOCH) @@ -55,6 +158,15 @@ pub fn tag_span_with_meta(span: &mut Span, meta: &CacheMeta) { }); } +#[derive(Debug)] +pub(crate) struct CacheTraceCTX { + // parent span + pub cache_span: Span, + // only spans across multiple calls need to store here + pub miss_span: Span, + pub hit_span: Span, +} + impl CacheTraceCTX { pub fn new() -> Self { CacheTraceCTX { diff --git a/pingora-proxy/Cargo.toml b/pingora-proxy/Cargo.toml index 27d98ddc..d4df1378 100644 --- a/pingora-proxy/Cargo.toml +++ b/pingora-proxy/Cargo.toml @@ -72,6 +72,7 @@ sentry = ["pingora-core/sentry"] adjust_upstream_modules = [] connection_filter = ["pingora-core/connection_filter"] prometheus = ["pingora-core/prometheus"] +trace = ["pingora-cache/trace"] [[example]] name = "connection_filter" diff --git a/pingora/Cargo.toml b/pingora/Cargo.toml index 8d2c25f7..d9fb57c2 100644 --- a/pingora/Cargo.toml +++ b/pingora/Cargo.toml @@ -153,3 +153,4 @@ document-features = [ "connection_filter" ] prometheus = ["pingora-core/prometheus"] +trace = ["pingora-cache?/trace", "pingora-proxy?/trace"] From b633683b7494d5c6cb03075c1efb12c62592c4e1 Mon Sep 17 00:00:00 2001 From: Kevin Guthrie Date: Fri, 20 Mar 2026 16:45:31 -0400 Subject: [PATCH 05/52] Fix listen fds not inherited during bootstrap_as_a_service graceful upgrade When bootstrap_as_a_service is enabled, listen_fds() was called to snapshot the fd table before BootstrapService had run, always returning None. Services would then bind fresh sockets instead of inheriting the old process's fds, breaking graceful upgrades. Fix this by eagerly allocating the ListenFds table in Bootstrap::new() so it is non-optional and already distributed to all services before bootstrap runs. When load_fds() later receives the inherited fds from the old process, it populates the same shared table in place, making them visible to all services without any re-distribution. --- .bleep | 2 +- pingora-core/src/server/bootstrap_services.rs | 13 +++--- pingora-core/src/server/mod.rs | 42 +++++++++---------- pingora-core/src/server/transfer_fd/mod.rs | 4 ++ 4 files changed, 33 insertions(+), 28 deletions(-) diff --git a/.bleep b/.bleep index 67f51e96..1a63c485 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -3aee79890a3f6c64716f16fe58324800909f8ff9 \ No newline at end of file +e8a0636c8132def54d2a5b11a618fa8ac42bd10f \ No newline at end of file diff --git a/pingora-core/src/server/bootstrap_services.rs b/pingora-core/src/server/bootstrap_services.rs index 0ad27ffc..3220e52e 100644 --- a/pingora-core/src/server/bootstrap_services.rs +++ b/pingora-core/src/server/bootstrap_services.rs @@ -58,7 +58,7 @@ pub struct Bootstrap { execution_phase_watch: broadcast::Sender, #[cfg(unix)] - listen_fds: Option, + listen_fds: ListenFds, #[cfg(feature = "sentry")] #[cfg_attr(docsrs, doc(cfg(feature = "sentry")))] @@ -95,7 +95,7 @@ impl Bootstrap { upgrade, upgrade_sock, #[cfg(unix)] - listen_fds: None, + listen_fds: Arc::new(TokioMutex::new(Fds::new())), execution_phase_watch: execution_phase_watch.clone(), completed: false, #[cfg(feature = "sentry")] @@ -186,17 +186,18 @@ impl Bootstrap { #[cfg(unix)] fn load_fds(&mut self, upgrade: bool) -> Result<(), nix::Error> { - let mut fds = Fds::new(); if upgrade { debug!("Trying to receive socks"); - fds.get_from_sock(self.upgrade_sock.as_str())? + let mut fds = Fds::new(); + fds.get_from_sock(self.upgrade_sock.as_str())?; + // Mutate through the existing Arc so all clones held by services see the update. + *self.listen_fds.blocking_lock() = fds; } - self.listen_fds = Some(Arc::new(TokioMutex::new(fds))); Ok(()) } #[cfg(unix)] - pub fn get_fds(&self) -> Option { + pub fn get_fds(&self) -> ListenFds { self.listen_fds.clone() } } diff --git a/pingora-core/src/server/mod.rs b/pingora-core/src/server/mod.rs index 80810e07..1f520f7a 100644 --- a/pingora-core/src/server/mod.rs +++ b/pingora-core/src/server/mod.rs @@ -272,8 +272,11 @@ impl Server { .send(ExecutionPhase::GracefulUpgradeTransferringFds) .ok(); - if let Some(fds) = self.listen_fds() { - let fds = fds.lock().await; + let fds = self.listen_fds(); + let fds = fds.lock().await; + if fds.is_empty() { + info!("No socks to send, shutting down."); + } else { info!("Trying to send socks"); // XXX: this is blocking IO match fds.send_to_sock(self.configuration.as_ref().upgrade_sock.as_str()) { @@ -291,24 +294,21 @@ impl Server { .send(ExecutionPhase::GracefulUpgradeCloseTimeout) .ok(); sleep(Duration::from_secs(CLOSE_TIMEOUT)).await; - info!("Broadcasting graceful shutdown"); - // gracefully exiting - match self.shutdown_watch.send(true) { - Ok(_) => { - info!("Graceful shutdown started!"); - } - Err(e) => { - error!("Graceful shutdown broadcast failed: {e}"); - // switch to fast shutdown - return ShutdownType::Graceful; - } + } + info!("Broadcasting graceful shutdown"); + // gracefully exiting + match self.shutdown_watch.send(true) { + Ok(_) => { + info!("Graceful shutdown started!"); + } + Err(e) => { + error!("Graceful shutdown broadcast failed: {e}"); + // switch to fast shutdown + return ShutdownType::Graceful; } - info!("Broadcast graceful shutdown complete"); - ShutdownType::Graceful - } else { - info!("No socks to send, shutting down."); - ShutdownType::Graceful } + info!("Broadcast graceful shutdown complete"); + ShutdownType::Graceful } } } @@ -360,14 +360,14 @@ impl Server { /// Get the configured file descriptors for listening #[cfg(unix)] - fn listen_fds(&self) -> Option { + fn listen_fds(&self) -> ListenFds { self.bootstrap.lock().get_fds() } #[allow(clippy::too_many_arguments)] fn run_service( mut service: Box, - #[cfg(unix)] fds: Option, + #[cfg(unix)] fds: ListenFds, shutdown: ShutdownWatch, threads: usize, work_stealing: bool, @@ -406,7 +406,7 @@ impl Server { service .start_service( #[cfg(unix)] - fds, + Some(fds), shutdown, listeners_per_fd, ready_notifier, diff --git a/pingora-core/src/server/transfer_fd/mod.rs b/pingora-core/src/server/transfer_fd/mod.rs index 3f852aec..a2fa58cc 100644 --- a/pingora-core/src/server/transfer_fd/mod.rs +++ b/pingora-core/src/server/transfer_fd/mod.rs @@ -50,6 +50,10 @@ impl Fds { self.map.get(bind) } + pub fn is_empty(&self) -> bool { + self.map.is_empty() + } + pub fn serialize(&self) -> (Vec, Vec) { self.map.iter().map(|(key, val)| (key.clone(), val)).unzip() } From c29014f59086c3616e5faf6fc1f96c2e2ee33b8e Mon Sep 17 00:00:00 2001 From: mariiaiurchenko Date: Tue, 17 Mar 2026 13:58:28 -0700 Subject: [PATCH 06/52] Retry on new h2 connection if spawn stream broken pipe --- .bleep | 2 +- pingora-core/src/connectors/http/v2.rs | 61 +++++++++++++++++++++++--- 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/.bleep b/.bleep index 1a63c485..0426a524 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -e8a0636c8132def54d2a5b11a618fa8ac42bd10f \ No newline at end of file +11f6d4a344e42cb6aaab1bc948b60ae2e7cc80c6 diff --git a/pingora-core/src/connectors/http/v2.rs b/pingora-core/src/connectors/http/v2.rs index c5ec42db..dd1d2b27 100644 --- a/pingora-core/src/connectors/http/v2.rs +++ b/pingora-core/src/connectors/http/v2.rs @@ -27,6 +27,7 @@ use parking_lot::{Mutex, RwLock}; use pingora_error::{Error, ErrorType::*, OrErr, Result}; use pingora_pool::{ConnectionMeta, ConnectionPool, PoolNode}; use std::collections::HashMap; +use std::io::ErrorKind; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; @@ -152,15 +153,23 @@ impl ConnectionRef { Err(e) => { // fail to create the stream, reset the counter self.0.current_streams.fetch_sub(1, Ordering::SeqCst); - // Remote sends GOAWAY(NO_ERROR): graceful shutdown: this connection no longer - // accepts new streams. We can still try to create new connection. - if e.root_cause() + + // Check for graceful shutdown conditions where we can retry with a new connection + let is_graceful_shutdown = e + .root_cause() .downcast_ref::() .map(|e| { - e.is_go_away() && e.is_remote() && e.reason() == Some(h2::Reason::NO_ERROR) + // Remote sends GOAWAY(NO_ERROR): graceful shutdown + (e.is_go_away() && e.is_remote() && e.reason() == Some(h2::Reason::NO_ERROR)) + // Or broken pipe wrapped inside an h2::Error: stream closed unexpectedly + || (e.is_io() + && e.get_io() + .map(|io| io.kind() == ErrorKind::BrokenPipe) + .unwrap_or(false)) }) - .unwrap_or(false) - { + .unwrap_or(false); + + if is_graceful_shutdown { self.mark_shutdown(); Ok(None) } else { @@ -682,6 +691,46 @@ mod tests { assert_eq!(id, h2_5.conn.id()); } + /// `spawn_stream` must return `Ok(None)` and mark the connection as shutting + /// down when the underlying I/O channel is closed (BrokenPipe). This + /// exercises the BrokenPipe branch of `spawn_stream` directly without going + /// through the full proxy stack. + #[tokio::test] + async fn test_spawn_stream_broken_pipe_marks_shutdown() { + let (client_io, server_io) = tokio::io::duplex(65536); + let (send_req, connection) = h2::client::handshake(client_io).await.unwrap(); + let (closed_tx, closed_rx) = watch::channel(false); + let ping_timeout = Arc::new(AtomicBool::new(false)); + let conn = ConnectionRef::new(send_req, closed_rx, ping_timeout, 0, 10, Digest::default()); + + // Drive the H2 client connection task in the background. + // When the connection terminates it will signal via closed_tx. + let conn_handle = tokio::spawn(async move { + let _ = connection.await; + // Signal that the connection task has finished. + let _ = closed_tx.send(true); + }); + + // Complete the server-side H2 handshake, then drop the server connection. + // Dropping server_conn closes the write end of the duplex, so the client + // connection task will read EOF and terminate with BrokenPipe. + let server_conn = h2::server::handshake(server_io).await.unwrap(); + drop(server_conn); + + // Wait until the client connection task has fully processed the EOF. + conn_handle.await.unwrap(); + + // spawn_stream must detect BrokenPipe, mark shutdown, and return Ok(None) + // so the caller can retry on a fresh connection rather than propagating the error. + let result = conn.spawn_stream().await; + assert!(result.is_ok(), "expected Ok(None), got Err"); + assert!(result.unwrap().is_none(), "expected None stream"); + assert!( + conn.is_shutting_down(), + "connection should be marked as shutting down" + ); + } + #[tokio::test] async fn test_mark_shutdown_prevents_new_streams() { let (client_io, _server_io) = tokio::io::duplex(65536); From 63c5f21dd0ec17c38eb359f7c8c3f53828276e86 Mon Sep 17 00:00:00 2001 From: lxga Date: Mon, 9 Mar 2026 14:57:51 +0000 Subject: [PATCH 07/52] Add abort_on_close functionality to HTTP session handling This update introduces the abort_on_close feature to control behavior when a client closes the connection after the request body. When enabled (default), it results in a ConnectionClosed error, allowing the proxy to abort immediately. When disabled, the proxy can continue processing the upstream response. Includes-commit: a6420f8a5b95d73ffbb697a675fef70555807704 Replicated-from: https://github.com/cloudflare/pingora/pull/836 --- .bleep | 2 +- pingora-core/src/protocols/http/server.rs | 16 ++ pingora-core/src/protocols/http/v1/server.rs | 184 ++++++++++++++++++- 3 files changed, 193 insertions(+), 9 deletions(-) diff --git a/.bleep b/.bleep index 0426a524..f6ed5435 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -11f6d4a344e42cb6aaab1bc948b60ae2e7cc80c6 +625f135160c620972a5e6882462a34465953984c diff --git a/pingora-core/src/protocols/http/server.rs b/pingora-core/src/protocols/http/server.rs index 051cc1f4..78852939 100644 --- a/pingora-core/src/protocols/http/server.rs +++ b/pingora-core/src/protocols/http/server.rs @@ -457,6 +457,22 @@ impl Session { } } + /// Controls behaviour when the client closes the connection after the request body. + /// + /// When **enabled** (default), a client close is returned as a `ConnectionClosed` + /// error so the proxy aborts immediately. When **disabled**, `read_body_or_idle` + /// stays pending so the proxy can finish delivering the upstream response. + /// + /// Only meaningful for H1 (TCP). Noop for H2/subrequest/custom. + pub fn set_abort_on_close(&mut self, abort: bool) { + match self { + Self::H1(s) => s.set_abort_on_close(abort), + Self::H2(_) => {} + Self::Subrequest(_) => {} + Self::Custom(_) => {} + } + } + /// Return a digest of the request including the method, path and Host header // TODO: make this use a `Formatter` pub fn request_summary(&self) -> String { diff --git a/pingora-core/src/protocols/http/v1/server.rs b/pingora-core/src/protocols/http/v1/server.rs index 03ebf81f..1a5e806b 100644 --- a/pingora-core/src/protocols/http/v1/server.rs +++ b/pingora-core/src/protocols/http/v1/server.rs @@ -91,6 +91,13 @@ pub struct HttpSession { /// Set by [`HttpPersistentSettings::apply_to_session`](crate::apps::HttpPersistentSettings::apply_to_session), /// consumed by the proxy layer via [`take_connection_user_context`](Self::take_connection_user_context). connection_user_context: Option>, + /// Whether the client has closed the TCP connection (sent FIN / read returned 0). + half_closed: bool, + /// When true (default), a client close after the request body is surfaced as a + /// `ConnectionClosed` error so the proxy aborts immediately. When false, the + /// close is tolerated and `read_body_or_idle` stays pending so the proxy can + /// finish delivering the upstream response (RFC 9112 Section 9.6). + abort_on_close: bool, } impl HttpSession { @@ -132,6 +139,8 @@ impl HttpSession { close_on_response_before_downstream_finish: true, keepalive_reuses_remaining: None, connection_user_context: None, + half_closed: false, + abort_on_close: true, } } @@ -961,19 +970,48 @@ impl HttpSession { /// This function will return body bytes (same as [`Self::read_body_bytes()`]), but after /// the client body finishes (`Ok(None)` is returned), calling this function again will block /// forever, same as [`Self::idle()`]. + /// + /// By default (`abort_on_close = true`), if the client closes the connection + /// (sends TCP FIN, i.e. `read == 0`) after the request body is complete, a + /// `ConnectionClosed` error is returned. + /// + /// When `abort_on_close` is **disabled**, the close is tolerated: the future stays + /// pending so the proxy can finish delivering the upstream response via the write + /// path (per RFC 9112 Section 9.6). A true disconnect (RST) will be caught later + /// when the response write fails. pub async fn read_body_or_idle(&mut self, no_body_expected: bool) -> Result> { if no_body_expected || self.is_body_done() { + if self.half_closed { + if self.abort_on_close { + return Error::e_explain( + ConnectionClosed, + if self.response_written.is_none() { + "Prematurely before response header is sent" + } else { + "Prematurely before response body is complete" + }, + ); + } + return std::future::pending().await; + } // XXX: account for upgraded body reader change, if the read half split from the write half let read = self.idle().await?; if read == 0 { - Error::e_explain( - ConnectionClosed, - if self.response_written.is_none() { - "Prematurely before response header is sent" - } else { - "Prematurely before response body is complete" - }, - ) + self.half_closed = true; + self.set_keepalive(None); + if self.abort_on_close { + Error::e_explain( + ConnectionClosed, + if self.response_written.is_none() { + "Prematurely before response header is sent" + } else { + "Prematurely before response body is complete" + }, + ) + } else { + debug!("downstream closed (FIN), keeping write side open"); + std::future::pending().await + } } else { Error::e_explain(ConnectError, "Sent data after end of body") } @@ -982,6 +1020,11 @@ impl HttpSession { } } + /// Whether the client has half-closed the TCP connection. + pub fn is_half_closed(&self) -> bool { + self.half_closed + } + /// Return the raw bytes of the request header. pub fn get_headers_raw_bytes(&self) -> Bytes { self.raw_header.as_ref().unwrap().get_bytes(&self.buf) @@ -1079,6 +1122,18 @@ impl HttpSession { self.close_on_response_before_downstream_finish = close; } + /// Controls behaviour when the client closes the connection after the request body. + /// + /// When **enabled** (default), a client close is returned as a `ConnectionClosed` + /// error so the proxy aborts immediately. + /// + /// When **disabled**, `read_body_or_idle` stays pending on a client close so the + /// proxy can finish delivering the upstream response (RFC 9112 Section 9.6). A true + /// disconnect (RST) will surface later when the response write fails. + pub fn set_abort_on_close(&mut self, abort: bool) { + self.abort_on_close = abort; + } + /// Return the [Digest] of the connection. pub fn digest(&self) -> &Digest { &self.digest @@ -2980,3 +3035,116 @@ mod test_overread { assert!(reused.is_none()); } } + +#[cfg(test)] +mod test_abort_on_close { + use super::*; + use pingora_error::ErrorType; + use tokio_test::io::Builder; + + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + /// Helper: create an HttpSession whose request has been read and body is done, + /// with the mock stream returning EOF on the next read (simulating client FIN). + async fn session_with_eof() -> HttpSession { + let request = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let mock_io = Builder::new().read(&request[..]).build(); + let mut s = HttpSession::new(Box::new(mock_io)); + s.read_request().await.unwrap(); + s + } + + #[tokio::test] + async fn default_abort_on_close_returns_error() { + init_log(); + let mut s = session_with_eof().await; + + assert!(s.abort_on_close); + let err = s.read_body_or_idle(true).await.unwrap_err(); + assert_eq!(*err.etype(), ErrorType::ConnectionClosed); + assert!(s.is_half_closed()); + } + + #[tokio::test] + async fn abort_on_close_false_stays_pending() { + init_log(); + let mut s = session_with_eof().await; + s.set_abort_on_close(false); + + let result = tokio::time::timeout( + std::time::Duration::from_millis(50), + s.read_body_or_idle(true), + ) + .await; + + assert!(result.is_err(), "expected timeout (pending), got a result"); + assert!(s.is_half_closed()); + } + + #[tokio::test] + async fn abort_on_close_error_message_before_response() { + init_log(); + let mut s = session_with_eof().await; + + assert!(s.response_written().is_none()); + let err = s.read_body_or_idle(true).await.unwrap_err(); + let msg = format!("{err}"); + assert!( + msg.contains("Prematurely before response header is sent"), + "unexpected error message: {msg}" + ); + } + + #[tokio::test] + async fn abort_on_close_error_message_after_response_header() { + init_log(); + let mut s = session_with_eof().await; + + // Simulate that a response header has already been sent. + let resp = ResponseHeader::build(200, None).unwrap(); + s.response_written = Some(Box::new(resp)); + let err = s.read_body_or_idle(true).await.unwrap_err(); + let msg = format!("{err}"); + assert!( + msg.contains("Prematurely before response body is complete"), + "unexpected error message: {msg}" + ); + } + + #[tokio::test] + async fn no_body_expected_false_reads_body_then_idles() { + init_log(); + let request = b"POST / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 3\r\n\r\n"; + let mock_io = Builder::new().read(&request[..]).read(b"abc").build(); + let mut s = HttpSession::new(Box::new(mock_io)); + s.read_request().await.unwrap(); + + // 1) no_body_expected = false should still read request body while not done. + let body = s.read_body_or_idle(false).await.unwrap().unwrap(); + assert_eq!(body.as_ref(), b"abc"); + assert!(s.is_body_done()); + + // 2) Once body is naturally done, it transitions to idle behavior on the next call. + let err = s.read_body_or_idle(false).await.unwrap_err(); + assert_eq!(*err.etype(), ErrorType::ConnectionClosed); + let msg = format!("{err}"); + assert!( + msg.contains("Prematurely before response header is sent"), + "unexpected error message: {msg}" + ); + } + + #[tokio::test] + async fn set_abort_on_close_toggles() { + init_log(); + let mut s = session_with_eof().await; + + assert!(s.abort_on_close); + s.set_abort_on_close(false); + assert!(!s.abort_on_close); + s.set_abort_on_close(true); + assert!(s.abort_on_close); + } +} From 22ffdb8726f6788de1fc19b9e0ee77a84d0b75e0 Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Tue, 24 Mar 2026 09:24:27 -0700 Subject: [PATCH 08/52] Add comments around pend behavior for abort_on_close --- .bleep | 2 +- pingora-core/src/protocols/http/v1/server.rs | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.bleep b/.bleep index f6ed5435..d7c87486 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -625f135160c620972a5e6882462a34465953984c +1cae16f3eded4ad3c408941fdc7ddf0697660303 diff --git a/pingora-core/src/protocols/http/v1/server.rs b/pingora-core/src/protocols/http/v1/server.rs index 1a5e806b..7e648ca5 100644 --- a/pingora-core/src/protocols/http/v1/server.rs +++ b/pingora-core/src/protocols/http/v1/server.rs @@ -979,6 +979,10 @@ impl HttpSession { /// pending so the proxy can finish delivering the upstream response via the write /// path (per RFC 9112 Section 9.6). A true disconnect (RST) will be caught later /// when the response write fails. + /// + /// Note that this marks the connection as half-closed if FIN is detected. If this function + /// is called after the connection is already marked half-closed and `abort_on_close` is + /// **disabled**, then it will pend forever. pub async fn read_body_or_idle(&mut self, no_body_expected: bool) -> Result> { if no_body_expected || self.is_body_done() { if self.half_closed { @@ -1010,6 +1014,8 @@ impl HttpSession { ) } else { debug!("downstream closed (FIN), keeping write side open"); + // If the connection is fully closed, writing the response side + // will fail. std::future::pending().await } } else { From c4beff8fd408064f360eb7893e50ddab31a365d1 Mon Sep 17 00:00:00 2001 From: Matthew Gumport Date: Wed, 25 Mar 2026 23:47:37 +0000 Subject: [PATCH 09/52] expose pipe_subrequest outcome Add fields such that callers can distinguish a successful subrequest from one that died silently or was cut short. The handle lets callers await post-response cleanup (cache writes, logging) before issuing the next subrequest. --- .bleep | 2 +- pingora-proxy/src/subrequest/pipe.rs | 111 +++++++++++++++++++++++---- 2 files changed, 96 insertions(+), 17 deletions(-) diff --git a/.bleep b/.bleep index d7c87486..9d8d04a5 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -1cae16f3eded4ad3c408941fdc7ddf0697660303 +4b3daa5fc9401e73ccf62698fa4d64969bafb2ad diff --git a/pingora-proxy/src/subrequest/pipe.rs b/pingora-proxy/src/subrequest/pipe.rs index 6dd4a57e..279a89de 100644 --- a/pingora-proxy/src/subrequest/pipe.rs +++ b/pingora-proxy/src/subrequest/pipe.rs @@ -42,16 +42,28 @@ pub enum InputBodyType { SaveBody(usize), } -/// Context struct as a result of subrequest piping. -#[derive(Clone)] +/// Outcome of [`pipe_subrequest`]. +#[derive(Debug, Default)] pub struct PipeSubrequestState { - /// The saved (captured) body from the main session. + /// Captured body from the main session. pub saved_body: Option, + /// Did the subrequest produce a response header? Checked before the task + /// filter runs, so a filtered-out header still counts. + pub header_received: bool, + /// The spawned subrequest task handle. Always set after spawn. Caller is + /// responsible for awaiting/inspecting state. + pub join_handle: Option>, } impl PipeSubrequestState { - fn new() -> PipeSubrequestState { - PipeSubrequestState { saved_body: None } + /// Creates a snapshot for error reporting, excluding the join handle. + /// Used by [`map_pipe_err`] to capture state at the point of failure. + pub fn snapshot_for_error(&self) -> Self { + PipeSubrequestState { + saved_body: self.saved_body.clone(), + header_received: self.header_received, + join_handle: None, + } } } @@ -81,7 +93,7 @@ fn map_pipe_err>>( from_subreq: bool, state: &PipeSubrequestState, ) -> Result { - result.map_err(|e| PipeSubrequestError::new(e, from_subreq, state.clone())) + result.map_err(|e| PipeSubrequestError::new(e, from_subreq, state.snapshot_for_error())) } #[derive(Debug, Clone)] @@ -182,13 +194,13 @@ where }; let mut downstream_state = DownstreamStateMachine::new(no_body_input); - let mut state = PipeSubrequestState::new(); - state.saved_body = saved_body; + let mut state = PipeSubrequestState { + saved_body, + ..Default::default() + }; - // Have the subrequest remove all body-related headers if no body will be sent - // TODO: we could also await the join handle, but subrequest may be running logging phase - // also the full run() may also await cache fill if downstream fails - let _join_handle = tokio::spawn(async move { + // Remove headers if no body. + let join_handle = tokio::spawn(async move { if no_body_input { subrequest .session_mut() @@ -196,8 +208,9 @@ where .expect("PreparedSubrequest must be subrequest") .clear_request_body_headers(); } - subrequest.run().await + let _ = subrequest.run().await; }); + state.join_handle = Some(join_handle); let tx = subrequest_handle.tx; let mut rx = subrequest_handle.rx; @@ -219,6 +232,10 @@ where task = rx.recv(), if !response_state.upstream_done() => { debug!("upstream event: {:?}", task); if let Some(t) = task { + // Did the subrequest get headers? + if matches!(&t, HttpTask::Header(..)) { + state.header_received = true; + } // pull as many tasks as we can const TASK_BUFFER_SIZE: usize = 4; let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); @@ -229,6 +246,9 @@ where // tokio::task::unconstrained because now_or_never may yield None when the future is ready while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { if let Some(t) = maybe_task { + if matches!(&t, HttpTask::Header(..)) { + state.header_received = true; + } let task = map_pipe_err(task_filter(t), false, &state)?; if let Some(filtered) = task { tasks.push(filtered); @@ -248,9 +268,7 @@ where // (can only happen with a real session, TODO to allow with preset body) downstream_state.maybe_finished(!use_preset_body && session.is_body_done()); } else { - // quite possible that the subrequest may be finished, though the main session - // is not - we still must exit in this case - debug!("empty upstream event"); + debug!("upstream channel closed early"); response_state.maybe_set_upstream_done(true); } }, @@ -397,3 +415,64 @@ fn do_send_body_to_pipe( Ok(end_of_body) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::subrequest::Ctx as SubrequestCtx; + use crate::{Session, Subrequest, SubrequestSpawner}; + use async_trait::async_trait; + use pingora_core::protocols::http::ServerSession as HttpSession; + use std::sync::Arc; + + /// Drops session without producing output — channels close, rx returns None. + struct NoopApp; + + #[async_trait] + impl Subrequest for NoopApp { + async fn process_subrequest( + self: Arc, + _session: Box, + _ctx: Box, + ) { + } + } + + async fn mock_session() -> Session { + let input = b"GET / HTTP/1.1\r\nHost: test\r\n\r\n"; + let mock_io = tokio_test::io::Builder::new().read(&input[..]).build(); + let mut session = Session::new_h1(Box::new(mock_io) as pingora_core::protocols::Stream); + session + .downstream_session + .read_request() + .await + .expect("mock request should parse"); + session + } + + #[tokio::test] + async fn no_header_received_when_subrequest_exits_silently() { + let mut session = mock_session().await; + + let spawner = SubrequestSpawner::new(Arc::new(NoopApp)); + let ctx = SubrequestCtx::builder().body_mode(BodyMode::NoBody).build(); + let (subrequest, handle) = spawner.create_subrequest(session.as_downstream(), ctx); + + let result = pipe_subrequest( + &mut session, + subrequest, + handle, + |task| Ok(Some(task)), + InputBodyType::Preset(InputBody::NoBody), + ) + .await; + + let state = + result.unwrap_or_else(|e| panic!("pipe should return Ok, not Err: {:?}", e.error)); + assert!( + !state.header_received, + "no header should have been received from the no-op subrequest" + ); + assert!(state.join_handle.is_some(), "task handle should be set"); + } +} From 542129fc9cf35c348e13bdbae3118bfe2a214f9b Mon Sep 17 00:00:00 2001 From: Kevin Guthrie Date: Mon, 23 Mar 2026 09:46:25 -0400 Subject: [PATCH 10/52] Fix flaky tests: test_tls_psk, test_conn_timeout, test_1xx_caching, listener port collisions test_conn_timeout / test_conn_timeout_with_offload: Replace 192.0.2.1 (TEST-~~~) with a bound-but-not-listening local socket via the new timeout_socket() helper in utils::for_testing. Because listen() is never called, the kernel silently drops SYN packets, guaranteeing a real ConnectTimedout on Linux. The total_connection_timeout tests still use 192.0.2.1 (SEMI_BLACKHOLE) since they test error classification and accept ConnectNoRoute as an alternative. test_tls_psk (s2n): PskTlsServer::start() spawned a background thread with no readiness check. Use an mpsc channel to signal after TcpListener::bind so tests only proceed once the port is ready. Also make the accept loop resilient to handshake failures (continue instead of panic) so a stale probe cannot take down the server. test_1xx_caching: mock_1xx_server used fixed ports (6151/6152) and sleep(100ms) for readiness. Refactored to spawn_mock_1xx_server which binds to port 0 (OS-assigned) and signals readiness via a oneshot channel after bind. Eliminates AddrInUse from TIME_WAIT and sleep races. test_listen_tcp / test_listen_tcp_ipv6_only: Hardcoded ports 7100-7102 collided across parallel CI test jobs. Switch to port 0 with the new ListenerEndpoint::local_addr() / Listener::local_addr() methods to discover the actual bound port. --- .bleep | 2 +- pingora-core/src/connectors/l4.rs | 23 ++++---- pingora-core/src/connectors/mod.rs | 61 +++++++++++++-------- pingora-core/src/listeners/l4.rs | 23 +++++--- pingora-core/src/listeners/mod.rs | 27 +++++----- pingora-core/src/protocols/l4/listener.rs | 14 +++++ pingora-proxy/tests/test_upstream.rs | 65 +++++++++++------------ pingora-proxy/tests/utils/mock_origin.rs | 21 ++++++-- pingora-proxy/tests/utils/server_utils.rs | 36 ++++++++++--- 9 files changed, 174 insertions(+), 98 deletions(-) diff --git a/.bleep b/.bleep index 9d8d04a5..1226d86e 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -4b3daa5fc9401e73ccf62698fa4d64969bafb2ad +2720504e0767063adc61da05ab8c7d34afa8671e diff --git a/pingora-core/src/connectors/l4.rs b/pingora-core/src/connectors/l4.rs index bd7439d4..d3baaa63 100644 --- a/pingora-core/src/connectors/l4.rs +++ b/pingora-core/src/connectors/l4.rs @@ -313,15 +313,9 @@ mod tests { use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; - use tokio::io::AsyncWriteExt; use tokio::time::sleep; - /// Some of the tests below are flaky when making new connections to mock - /// servers. The servers are simple tokio listeners, so failures there are - /// not indicative of real errors. This function will retry the peer/server - /// in increasing intervals until it either succeeds in connecting or a long - /// timeout expires (max 10sec) - #[cfg(unix)] + #[cfg(target_os = "linux")] async fn wait_for_peer

(peer: &P) where P: Peer + Send + Sync, @@ -394,11 +388,17 @@ mod tests { #[tokio::test] async fn test_conn_timeout() { - // 192.0.2.1 is effectively a blackhole + // 192.0.2.1 is TEST-NET-1 (RFC 5737) — SYN packets are silently + // dropped on Linux, producing ConnectTimedout. On macOS the kernel + // may instead return ENETUNREACH (ConnectNoRoute). let mut peer = BasicPeer::new("192.0.2.1:79"); - peer.options.connection_timeout = Some(std::time::Duration::from_millis(1)); //1ms - let new_session = connect(&peer, None).await; - assert_eq!(new_session.unwrap_err().etype(), &ConnectTimedout) + peer.options.connection_timeout = Some(Duration::from_millis(1)); + let err = connect(&peer, None).await.unwrap_err(); + assert!( + err.etype() == &ConnectTimedout || err.etype() == &ConnectNoRoute, + "unexpected error type: {:?}", + err.etype() + ); } #[tokio::test] @@ -537,6 +537,7 @@ mod tests { // one-off mock server async fn mock_inet_connect_server() -> u16 { + use tokio::io::AsyncWriteExt; use tokio::net::TcpListener; let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/pingora-core/src/connectors/mod.rs b/pingora-core/src/connectors/mod.rs index 3e3c1c46..e5e987cb 100644 --- a/pingora-core/src/connectors/mod.rs +++ b/pingora-core/src/connectors/mod.rs @@ -482,15 +482,14 @@ pub(crate) mod test_utils { #[cfg(test)] #[cfg(feature = "any_tls")] mod tests { + use std::time::Duration; + use pingora_error::ErrorType; use tls::Connector; use super::*; use crate::upstreams::peer::BasicPeer; - // 192.0.2.1 is effectively a black hole - const BLACK_HOLE: &str = "192.0.2.1:79"; - #[tokio::test] async fn test_connect() { let connector = TransportConnector::new(None); @@ -547,15 +546,23 @@ mod tests { server_handle.await.unwrap(); } + // 192.0.2.1 is TEST-NET-1 (RFC 5737) — SYN packets are silently + // dropped on Linux, producing ConnectTimedout. On macOS the kernel + // may instead return ENETUNREACH (ConnectNoRoute). + const BLACKHOLE: &str = "192.0.2.1:79"; + async fn do_test_conn_timeout(conf: Option) { let connector = TransportConnector::new(conf); - let mut peer = BasicPeer::new(BLACK_HOLE); - peer.options.connection_timeout = Some(std::time::Duration::from_millis(1)); - let stream = connector.new_stream(&peer).await; - match stream { - Ok(_) => panic!("should throw an error"), - Err(e) => assert_eq!(e.etype(), &ConnectTimedout), - } + let mut peer = BasicPeer::new(BLACKHOLE); + peer.options.connection_timeout = Some(Duration::from_millis(1)); + let Err(e) = connector.new_stream(&peer).await else { + panic!("should throw an error"); + }; + assert!( + e.etype() == &ConnectTimedout || e.etype() == &ConnectNoRoute, + "unexpected error type: {:?}", + e.etype() + ); } #[tokio::test] @@ -585,8 +592,8 @@ mod tests { } /// Helper function for testing error handling in the `do_connect` function. - /// This assumes that the connection will fail to on the peer and returns - /// the decomposed error type and message + /// This assumes that the connection will fail on the peer and returns + /// the decomposed error type and message. async fn get_do_connect_failure_with_peer(peer: &BasicPeer) -> (ErrorType, String) { let tls_connector = Connector::new(None); let stream = do_connect(peer, None, None, &tls_connector.ctx).await; @@ -604,26 +611,36 @@ mod tests { #[tokio::test] async fn test_do_connect_with_total_timeout() { - let mut peer = BasicPeer::new(BLACK_HOLE); - peer.options.total_connection_timeout = Some(std::time::Duration::from_millis(1)); + let mut peer = BasicPeer::new(BLACKHOLE); + peer.options.total_connection_timeout = Some(Duration::from_millis(1)); let (etype, context) = get_do_connect_failure_with_peer(&peer).await; - assert_eq!(etype, ConnectTimedout); - assert!(context.contains("total-connection timeout")); + assert!( + etype == ConnectTimedout || etype == ConnectNoRoute, + "unexpected error type: {etype:?}" + ); + if etype == ConnectTimedout { + assert!(context.contains("total-connection timeout")); + } } #[tokio::test] async fn test_tls_connect_timeout_supersedes_total() { - let mut peer = BasicPeer::new(BLACK_HOLE); - peer.options.total_connection_timeout = Some(std::time::Duration::from_millis(10)); - peer.options.connection_timeout = Some(std::time::Duration::from_millis(1)); + let mut peer = BasicPeer::new(BLACKHOLE); + peer.options.total_connection_timeout = Some(Duration::from_millis(10)); + peer.options.connection_timeout = Some(Duration::from_millis(1)); let (etype, context) = get_do_connect_failure_with_peer(&peer).await; - assert_eq!(etype, ConnectTimedout); - assert!(!context.contains("total-connection timeout")); + assert!( + etype == ConnectTimedout || etype == ConnectNoRoute, + "unexpected error type: {etype:?}" + ); + if etype == ConnectTimedout { + assert!(!context.contains("total-connection timeout")); + } } #[tokio::test] async fn test_do_connect_without_total_timeout() { - let peer = BasicPeer::new(BLACK_HOLE); + let peer = BasicPeer::new(BLACKHOLE); let (etype, context) = get_do_connect_failure_with_peer(&peer).await; assert!(etype != ConnectTimedout || !context.contains("total-connection timeout")); } diff --git a/pingora-core/src/listeners/l4.rs b/pingora-core/src/listeners/l4.rs index 1c0052f8..b965ad6f 100644 --- a/pingora-core/src/listeners/l4.rs +++ b/pingora-core/src/listeners/l4.rs @@ -386,6 +386,15 @@ impl ListenerEndpoint { self.listen_addr.as_ref() } + /// Return the local address this endpoint is bound to. + /// + /// Useful when the listener was bound to port 0 (OS-assigned) to + /// discover the actual port. + #[cfg(test)] + pub fn local_addr(&self) -> Option { + self.listener.local_addr() + } + fn apply_stream_settings(&self, stream: &mut Stream) -> Result<()> { // settings are applied based on whether the underlying stream supports it stream.set_nodelay()?; @@ -470,11 +479,9 @@ mod test { #[tokio::test] async fn test_listen_tcp() { - let addr = "127.0.0.1:7100"; - let mut builder = ListenerEndpoint::builder(); - builder.listen_addr(ServerAddress::Tcp(addr.into(), None)); + builder.listen_addr(ServerAddress::Tcp("127.0.0.1:0".into(), None)); #[cfg(unix)] let listener = builder.listen(None).await.unwrap(); @@ -482,6 +489,8 @@ mod test { #[cfg(windows)] let listener = builder.listen().await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { // just try to accept once listener.accept().await.unwrap(); @@ -500,7 +509,7 @@ mod test { let mut builder = ListenerEndpoint::builder(); - builder.listen_addr(ServerAddress::Tcp("[::]:7101".into(), sock_opt)); + builder.listen_addr(ServerAddress::Tcp("[::]:0".into(), sock_opt)); #[cfg(unix)] let listener = builder.listen(None).await.unwrap(); @@ -508,15 +517,17 @@ mod test { #[cfg(windows)] let listener = builder.listen().await.unwrap(); + let port = listener.local_addr().unwrap().port(); + tokio::spawn(async move { // just try to accept twice listener.accept().await.unwrap(); listener.accept().await.unwrap(); }); - tokio::net::TcpStream::connect("127.0.0.1:7101") + tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")) .await .expect_err("cannot connect to v4 addr"); - tokio::net::TcpStream::connect("[::1]:7101") + tokio::net::TcpStream::connect(format!("[::1]:{port}")) .await .expect("can connect to v6 addr"); } diff --git a/pingora-core/src/listeners/mod.rs b/pingora-core/src/listeners/mod.rs index e44f1735..f2e649f8 100644 --- a/pingora-core/src/listeners/mod.rs +++ b/pingora-core/src/listeners/mod.rs @@ -337,14 +337,11 @@ mod test { #[cfg(feature = "any_tls")] use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; - use tokio::time::{sleep, Duration}; #[tokio::test] async fn test_listen_tcp() { - let addr1 = "127.0.0.1:7101"; - let addr2 = "127.0.0.1:7102"; - let mut listeners = Listeners::tcp(addr1); - listeners.add_tcp(addr2); + let mut listeners = Listeners::tcp("127.0.0.1:0"); + listeners.add_tcp("127.0.0.1:0"); let listeners = listeners .build( @@ -355,6 +352,10 @@ mod test { .unwrap(); assert_eq!(listeners.len(), 2); + let addrs: Vec<_> = listeners + .iter() + .map(|s| s.l4.local_addr().unwrap()) + .collect(); for listener in listeners { tokio::spawn(async move { // just try to accept once @@ -363,11 +364,12 @@ mod test { }); } - // make sure the above starts before the lines below - sleep(Duration::from_millis(10)).await; - - TcpStream::connect(addr1).await.unwrap(); - TcpStream::connect(addr2).await.unwrap(); + // The listeners are already bound (port resolved during build()), + // so the kernel accepts connections into the backlog immediately. + // No readiness wait needed — connect will succeed as soon as the + // OS has completed the TCP handshake. + TcpStream::connect(addrs[0]).await.unwrap(); + TcpStream::connect(addrs[1]).await.unwrap(); } #[tokio::test] @@ -400,9 +402,8 @@ mod test { .await .unwrap(); }); - // make sure the above starts before the lines below - sleep(Duration::from_millis(10)).await; - + // The listener is already bound, so the kernel accepts connections + // into the backlog immediately. No readiness wait needed. let client = reqwest::Client::builder() .danger_accept_invalid_certs(true) .build() diff --git a/pingora-core/src/protocols/l4/listener.rs b/pingora-core/src/protocols/l4/listener.rs index 7d00005e..a6055267 100644 --- a/pingora-core/src/protocols/l4/listener.rs +++ b/pingora-core/src/protocols/l4/listener.rs @@ -67,6 +67,20 @@ impl AsRawSocket for Listener { } impl Listener { + /// Return the local address this listener is bound to. + /// + /// For TCP listeners this is the resolved address (including the + /// OS-assigned port when the listener was bound to port 0). + /// Returns `None` for non-TCP listeners (e.g. Unix domain sockets). + #[cfg(test)] + pub fn local_addr(&self) -> Option { + match self { + Self::Tcp(l) => l.local_addr().ok(), + #[cfg(unix)] + Self::Unix(_) => None, + } + } + /// Accept a connection from the listening endpoint pub async fn accept(&self) -> io::Result { match &self { diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs index e1ac37f8..9ae4511e 100644 --- a/pingora-proxy/tests/test_upstream.rs +++ b/pingora-proxy/tests/test_upstream.rs @@ -482,8 +482,9 @@ async fn test_h2_upstream_no_end_stream_read_timeout() { } }); - tokio::time::sleep(Duration::from_millis(50)).await; - + // The listener was bound before the spawn (line 426), so the kernel + // is already accepting connections into the backlog. No readiness + // wait needed. let client = reqwest::Client::new(); let url = "http://127.0.0.1:6147/test"; @@ -1379,35 +1380,37 @@ mod test_cache { // set up a one-off mock server // (warp / hyper don't have custom 1xx sending capabilities yet) - async fn mock_1xx_server(port: u16, cc_header: &str) { - use tokio::io::AsyncWriteExt; - - let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) - .await - .unwrap(); - if let Ok((mut stream, _addr)) = listener.accept().await { - stream.write_all(b"HTTP/1.1 103 Early Hints\r\nLink: ; rel=preconnect\r\n\r\n").await.unwrap(); - // wait a bit so that the client can read - sleep(Duration::from_millis(100)).await; - stream.write_all(format!("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nCache-Control: {}\r\n\r\nhello", cc_header).as_bytes()).await.unwrap(); - sleep(Duration::from_millis(100)).await; - } + // One-shot mock server that sends a 103 Early Hints then a final 200. + // Binds to port 0 (OS-assigned) and returns the actual port via a + // oneshot channel once the listener is ready. + fn spawn_mock_1xx_server(cc_header: &'static str) -> tokio::sync::oneshot::Receiver { + let (tx, rx) = tokio::sync::oneshot::channel(); + tokio::spawn(async move { + use tokio::io::AsyncWriteExt; + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let _ = tx.send(port); // signal: port is bound + if let Ok((mut stream, _addr)) = listener.accept().await { + stream.write_all(b"HTTP/1.1 103 Early Hints\r\nLink: ; rel=preconnect\r\n\r\n").await.unwrap(); + sleep(Duration::from_millis(100)).await; + stream.write_all(format!("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nCache-Control: {}\r\n\r\nhello", cc_header).as_bytes()).await.unwrap(); + sleep(Duration::from_millis(100)).await; + } + }); + rx } init(); let url = "http://127.0.0.1:6148/unique/test_1xx_caching"; - tokio::spawn(async { - mock_1xx_server(6151, "max-age=5").await; - }); - // wait for server to start - sleep(Duration::from_millis(100)).await; + let port = spawn_mock_1xx_server("max-age=5").await.unwrap(); let client = reqwest::Client::new(); let res = client .get(url) - .header("x-port", "6151") + .header("x-port", port.to_string()) .send() .await .unwrap(); @@ -1416,9 +1419,10 @@ mod test_cache { assert_eq!(headers["x-cache-status"], "miss"); assert_eq!(res.text().await.unwrap(), "hello"); + // Second request to the same URL should be a cache hit (no server needed) let res = client .get(url) - .header("x-port", "6151") + .header("x-port", port.to_string()) .send() .await .unwrap(); @@ -1430,15 +1434,11 @@ mod test_cache { // 1xx shouldn't interfere with bypass let url = "http://127.0.0.1:6148/unique/test_1xx_bypass"; - tokio::spawn(async { - mock_1xx_server(6152, "private, no-store").await; - }); - // wait for server to start - sleep(Duration::from_millis(100)).await; + let port = spawn_mock_1xx_server("private, no-store").await.unwrap(); let res = client .get(url) - .header("x-port", "6152") + .header("x-port", port.to_string()) .send() .await .unwrap(); @@ -1448,16 +1448,11 @@ mod test_cache { assert_eq!(res.text().await.unwrap(), "hello"); // restart the one-off server - still uncacheable - sleep(Duration::from_millis(100)).await; - tokio::spawn(async { - mock_1xx_server(6152, "private, no-store").await; - }); - // wait for server to start - sleep(Duration::from_millis(100)).await; + let port = spawn_mock_1xx_server("private, no-store").await.unwrap(); let res = client .get(url) - .header("x-port", "6152") + .header("x-port", port.to_string()) .send() .await .unwrap(); diff --git a/pingora-proxy/tests/utils/mock_origin.rs b/pingora-proxy/tests/utils/mock_origin.rs index 74840e19..fa5a327f 100644 --- a/pingora-proxy/tests/utils/mock_origin.rs +++ b/pingora-proxy/tests/utils/mock_origin.rs @@ -59,7 +59,22 @@ fn init() -> bool { .output() .unwrap(); }); - // wait until the server is up - thread::sleep(time::Duration::from_secs(2)); - true + // Wait until openresty is accepting connections, then give it a moment + // to finish worker initialization. + let deadline = time::Instant::now() + time::Duration::from_secs(10); + while time::Instant::now() < deadline { + if std::net::TcpStream::connect_timeout( + &"127.0.0.1:8000".parse().unwrap(), + time::Duration::from_millis(100), + ) + .is_ok() + { + // Port is listening; allow a brief window for workers to finish + // initializing before tests start sending real requests. + thread::sleep(time::Duration::from_millis(500)); + return true; + } + thread::sleep(time::Duration::from_millis(50)); + } + panic!("mock origin (openresty) failed to start within 10s"); } diff --git a/pingora-proxy/tests/utils/server_utils.rs b/pingora-proxy/tests/utils/server_utils.rs index 5a418934..9361182e 100644 --- a/pingora-proxy/tests/utils/server_utils.rs +++ b/pingora-proxy/tests/utils/server_utils.rs @@ -873,16 +873,28 @@ pub struct PskTlsServer { #[cfg(feature = "s2n")] impl PskTlsServer { pub fn start() -> Self { - let server_handle = thread::spawn(|| { + use std::sync::mpsc; + use std::time::Duration; + + // Use a channel to wait for the server to bind its port. + // A TCP probe can't be used here because the TLS acceptor would + // try to handshake the probe connection, fail, and panic. + let (tx, rx) = mpsc::channel(); + let server_handle = thread::spawn(move || { let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(Self::run_server()); + rt.block_on(Self::run_server(tx)); }); + + // Wait up to 10s for the server to signal it has bound the port. + rx.recv_timeout(Duration::from_secs(10)) + .expect("PSK TLS server failed to start within 10s"); + PskTlsServer { handle: server_handle, } } - async fn run_server() { + async fn run_server(ready_tx: std::sync::mpsc::Sender<()>) { use pingora_core::{protocols::tls::S2NConnectionBuilder, tls::TlsAcceptor}; use pingora_core::{ protocols::tls::{Psk, PskConfig, PskType}, @@ -899,6 +911,8 @@ impl PskTlsServer { let addr: std::net::SocketAddr = "127.0.0.1:6151".parse().unwrap(); let listener = TcpListener::bind(addr).await.unwrap(); + let _ = ready_tx.send(()); // signal: port is bound + let mut config_builder = Config::builder(); unsafe { config_builder.disable_x509_verification(); @@ -915,12 +929,20 @@ impl PskTlsServer { let acceptor = TlsAcceptor::new(connection_builder); loop { - use tokio::{io::AsyncWriteExt, net::tcp}; + use tokio::io::AsyncWriteExt; let (tcp_stream, _) = listener.accept().await.unwrap(); - let mut stream = acceptor.clone().accept(tcp_stream).await.unwrap(); + // Don't panic on handshake failure — a stale connection or probe + // shouldn't take down the server for subsequent real connections. + let mut stream = match acceptor.clone().accept(tcp_stream).await { + Ok(s) => s, + Err(e) => { + log::warn!("PSK TLS server: handshake failed: {e}"); + continue; + } + }; let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"; - stream.write_all(response).await.unwrap(); - stream.shutdown().await; + let _ = stream.write_all(response).await; + let _ = stream.shutdown().await; } } } From 1d9371191862d25d9314ad299a0ef8d3e514600c Mon Sep 17 00:00:00 2001 From: Kevin Guthrie Date: Wed, 25 Mar 2026 13:04:26 -0400 Subject: [PATCH 11/52] Replace tokio::sync::Mutex with parking_lot::Mutex for ListenFds ListenFds only guards an in-memory fd table and a blocking send_to_sock call, neither of which benefit from an async mutex. Switch to parking_lot::Mutex and move the fd-send path in main_loop onto the blocking thread pool via spawn_blocking. Because the parking_lot lock cannot be held across bind().await in ListenerEndpointBuilder::listen(), introduce a global per-address async lock map (flurry::HashMap>>) that serializes the check-bind-insert sequence for each address. This prevents two concurrent callers from racing to bind the same address while the ListenFds lock is released. --- .bleep | 2 +- pingora-core/Cargo.toml | 1 + pingora-core/src/listeners/l4.rs | 47 ++++++++++++++++--- pingora-core/src/server/bootstrap_services.rs | 6 +-- pingora-core/src/server/mod.rs | 40 +++++++++------- 5 files changed, 67 insertions(+), 29 deletions(-) diff --git a/.bleep b/.bleep index 1226d86e..2911a4d3 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -2720504e0767063adc61da05ab8c7d34afa8671e +67cc768b717d3865a73bcd917c905d7d9aeb4c62 diff --git a/pingora-core/Cargo.toml b/pingora-core/Cargo.toml index 947a92cf..e2854966 100644 --- a/pingora-core/Cargo.toml +++ b/pingora-core/Cargo.toml @@ -76,6 +76,7 @@ daggy = "0.8" [target.'cfg(unix)'.dependencies] daemonize = "0.5.0" +flurry = "0.5" nix = "~0.24.3" [target.'cfg(windows)'.dependencies] diff --git a/pingora-core/src/listeners/l4.rs b/pingora-core/src/listeners/l4.rs index b965ad6f..5532b635 100644 --- a/pingora-core/src/listeners/l4.rs +++ b/pingora-core/src/listeners/l4.rs @@ -44,6 +44,20 @@ use crate::protocols::GetSocketDigest; use crate::protocols::TcpKeepalive; #[cfg(unix)] use crate::server::ListenFds; +#[cfg(unix)] +use std::sync::LazyLock; + +/// Per-address async lock map for serializing the check-bind-insert sequence +/// in [`ListenerEndpointBuilder::listen`]. +/// +/// With `ListenFds` using a synchronous `parking_lot::Mutex`, the lock cannot +/// be held across `bind().await`. This global map ensures that only one task at +/// a time can be in the process of looking up, binding, and inserting a given +/// address — preventing two concurrent callers from both seeing "not found" and +/// racing to bind the same address. +#[cfg(unix)] +static BIND_LOCKS: LazyLock>>> = + LazyLock::new(flurry::HashMap::new); const TCP_LISTENER_MAX_TRY: usize = 30; const TCP_LISTENER_TRY_STEP: Duration = Duration::from_secs(1); @@ -326,19 +340,38 @@ impl ListenerEndpointBuilder { let listener = if let Some(fds_table) = fds { let addr_str = listen_addr.as_ref(); - // consider make this mutex std::sync::Mutex or OnceCell - let mut table = fds_table.lock().await; + // Acquire a per-address async lock so that only one task at a + // time can go through the check-bind-insert sequence for a given + // address. The flurry guard is dropped before the await so its + // !Send pointer does not cross an await point. + let addr_lock = { + let guard = BIND_LOCKS.pin(); + match guard.get(addr_str) { + Some(existing) => existing.clone(), + None => { + let new_lock = Arc::new(tokio::sync::Mutex::new(())); + match guard.try_insert(addr_str.to_string(), new_lock.clone()) { + Ok(inserted) => inserted.clone(), + Err(e) => e.current.clone(), + } + } + } + }; + let _guard = addr_lock.lock().await; + + let existing_fd = fds_table.lock().get(addr_str).copied(); - if let Some(fd) = table.get(addr_str) { - from_raw_fd(&listen_addr, *fd)? + if let Some(fd) = existing_fd { + from_raw_fd(&listen_addr, fd)? } else { - // not found let listener = bind(&listen_addr).await?; - table.add(addr_str.to_string(), listener.as_raw_fd()); + fds_table + .lock() + .add(addr_str.to_string(), listener.as_raw_fd()); listener } } else { - // not found, no fd table + // no fd table bind(&listen_addr).await? }; diff --git a/pingora-core/src/server/bootstrap_services.rs b/pingora-core/src/server/bootstrap_services.rs index 3220e52e..ca5bfad0 100644 --- a/pingora-core/src/server/bootstrap_services.rs +++ b/pingora-core/src/server/bootstrap_services.rs @@ -18,7 +18,7 @@ use async_trait::async_trait; use log::{debug, error, info}; use parking_lot::Mutex; use std::sync::Arc; -use tokio::sync::{broadcast, Mutex as TokioMutex}; +use tokio::sync::broadcast; #[cfg(feature = "sentry")] use sentry::ClientOptions; @@ -95,7 +95,7 @@ impl Bootstrap { upgrade, upgrade_sock, #[cfg(unix)] - listen_fds: Arc::new(TokioMutex::new(Fds::new())), + listen_fds: Arc::new(Mutex::new(Fds::new())), execution_phase_watch: execution_phase_watch.clone(), completed: false, #[cfg(feature = "sentry")] @@ -191,7 +191,7 @@ impl Bootstrap { let mut fds = Fds::new(); fds.get_from_sock(self.upgrade_sock.as_str())?; // Mutate through the existing Arc so all clones held by services see the update. - *self.listen_fds.blocking_lock() = fds; + *self.listen_fds.lock() = fds; } Ok(()) } diff --git a/pingora-core/src/server/mod.rs b/pingora-core/src/server/mod.rs index 1f520f7a..ffedf665 100644 --- a/pingora-core/src/server/mod.rs +++ b/pingora-core/src/server/mod.rs @@ -36,7 +36,7 @@ use std::thread; use std::time::SystemTime; #[cfg(unix)] use tokio::signal::unix; -use tokio::sync::{broadcast, watch, Mutex as TokioMutex}; +use tokio::sync::{broadcast, watch}; use tokio::time::{sleep, Duration}; use crate::prelude::background_service; @@ -117,7 +117,7 @@ pub enum ExecutionPhase { /// to shutdown pub type ShutdownWatch = watch::Receiver; #[cfg(unix)] -pub type ListenFds = Arc>; +pub type ListenFds = Arc>; /// The type of shutdown process that has been requested. #[derive(Debug)] @@ -272,24 +272,28 @@ impl Server { .send(ExecutionPhase::GracefulUpgradeTransferringFds) .ok(); - let fds = self.listen_fds(); - let fds = fds.lock().await; - if fds.is_empty() { - info!("No socks to send, shutting down."); - } else { - info!("Trying to send socks"); - // XXX: this is blocking IO - match fds.send_to_sock(self.configuration.as_ref().upgrade_sock.as_str()) { - Ok(_) => { - info!("listener sockets sent"); - } - Err(e) => { - error!("Unable to send listener sockets to new process: {e}"); - // sentry log error on fd send failure - #[cfg(all(not(debug_assertions), feature = "sentry"))] - sentry::capture_error(&e); + let sent_fds = { + let fds = self.listen_fds(); + let fds = fds.lock(); + if fds.is_empty() { + info!("No socks to send, shutting down."); + false + } else { + info!("Trying to send socks"); + match fds.send_to_sock(self.configuration.as_ref().upgrade_sock.as_str()) { + Ok(_) => { + info!("listener sockets sent"); + } + Err(e) => { + error!("Unable to send listener sockets to new process: {e}"); + #[cfg(all(not(debug_assertions), feature = "sentry"))] + sentry::capture_error(&e); + } } + true } + }; + if sent_fds { self.execution_phase_watch .send(ExecutionPhase::GracefulUpgradeCloseTimeout) .ok(); From 9855feb57c6864caf6cb0e9cf5bbf6362de8c1d1 Mon Sep 17 00:00:00 2001 From: zaidoon Date: Fri, 10 Apr 2026 14:01:45 -0400 Subject: [PATCH 12/52] ci: use cargo check for MSRV instead of cargo test The MSRV (1.84.0) job fails because cargo test compiles dev-dependencies. A transitive dev-dependency chain (pingora-proxy -> tokio-tungstenite -> tungstenite -> sha1 -> cpufeatures v0.3.0) pulls in a crate that uses edition 2024, which Cargo 1.84.0 cannot parse. Run cargo check --workspace for all toolchains and skip cargo test on the MSRV. --- .github/workflows/build.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 22a4c458..adf2b014 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -38,12 +38,17 @@ jobs: - name: Run cargo fmt run: cargo fmt --all -- --check + - name: Run cargo check + run: cargo check --workspace + - name: Run cargo test + if: matrix.toolchain != '1.84.0' run: cargo test --verbose --lib --bins --tests --no-fail-fast # Need to run doc tests separately. # (https://github.com/rust-lang/cargo/issues/6669) - name: Run cargo doc test + if: matrix.toolchain != '1.84.0' run: cargo test --verbose --doc - name: Run cargo clippy From ee387f4ab1ba00fe28f21332e984fcbe430c85b7 Mon Sep 17 00:00:00 2001 From: Kevin Guthrie Date: Sat, 28 Mar 2026 21:23:15 -0400 Subject: [PATCH 13/52] Add a mechanism for signalling between old and new processes when doing graceful upgrades --- .bleep | 2 +- docs/user_guide/conf.md | 3 + pingora-core/src/server/bootstrap_services.rs | 44 ++- pingora-core/src/server/configuration/mod.rs | 46 +++ pingora-core/src/server/daemon.rs | 349 +++++++++++++++++- pingora-core/src/server/mod.rs | 7 +- pingora-core/tests/bootstrap_as_a_service.rs | 136 +++++++ pingora/examples/graceful_upgrade.rs | 186 ++++++++++ 8 files changed, 759 insertions(+), 14 deletions(-) create mode 100644 pingora-core/tests/bootstrap_as_a_service.rs create mode 100644 pingora/examples/graceful_upgrade.rs diff --git a/.bleep b/.bleep index 2911a4d3..a5fdccd2 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -67cc768b717d3865a73bcd917c905d7d9aeb4c62 +3f438f804b9e954f7b882815fc9e70ebe9d572a0 \ No newline at end of file diff --git a/docs/user_guide/conf.md b/docs/user_guide/conf.md index 1f55859e..70a8f569 100644 --- a/docs/user_guide/conf.md +++ b/docs/user_guide/conf.md @@ -29,6 +29,9 @@ group: webusers | s2n_config_cache_size | The maximum number of unique s2n configs to cache. A value of 0 disables the cache. Default: 10 (s2n-tls only) | number | | work_stealing | Enable work stealing runtime (default true). See Pingora runtime (WIP) section for more info | bool | | upstream_keepalive_pool_size | The number of total connections to keep in the connection pool | number | +| daemon_wait_for_ready | When `true` and `daemon` is `true`, the parent process waits for the daemon to signal readiness (via `SIGUSR1`) before exiting. This causes systemd to delay sending `SIGQUIT` to the old process until the new instance is fully bootstrapped. Default: `false` | bool | +| daemon_ready_timeout_seconds | How long (in seconds) the parent waits for the daemon to signal readiness when `daemon_wait_for_ready` is `true`. If the daemon does not signal in time the parent exits with a non-zero code, causing systemd to abort the reload. Default: `600` | number | +| daemon_notify_timeout_seconds | How long (in seconds) the daemon retries sending `SIGUSR1` to the parent when the attempt fails with a permission error. This covers the brief window after the fork where the parent has not yet dropped its UID to match the daemon. Default: `60` | number | ## Extension Any unknown settings will be ignored. This allows extending the conf file to add and pass user defined settings. See User defined configuration section. diff --git a/pingora-core/src/server/bootstrap_services.rs b/pingora-core/src/server/bootstrap_services.rs index ca5bfad0..74c81d79 100644 --- a/pingora-core/src/server/bootstrap_services.rs +++ b/pingora-core/src/server/bootstrap_services.rs @@ -23,8 +23,12 @@ use tokio::sync::broadcast; #[cfg(feature = "sentry")] use sentry::ClientOptions; +#[cfg(unix)] +use crate::server::daemon::notify_parent_ready_for_fds; #[cfg(unix)] use crate::server::ListenFds; +#[cfg(unix)] +use std::time::Duration; use crate::{ prelude::Opt, @@ -32,6 +36,10 @@ use crate::{ services::{background::BackgroundService, ServiceReadyNotifier}, }; +/// Default timeout for retrying `SIGUSR1` to the parent when it fails with `EPERM`. +#[cfg(unix)] +const DEFAULT_DAEMON_NOTIFY_TIMEOUT: Duration = Duration::from_secs(60); + /// Service that allows the bootstrap process to be delayed until after /// dependencies are ready pub struct BootstrapService { @@ -60,6 +68,16 @@ pub struct Bootstrap { #[cfg(unix)] listen_fds: ListenFds, + /// PID of the original parent process to notify via `SIGUSR1` after bootstrap completes. + /// Set when [`ServerConf::daemon_wait_for_ready`] is `true`. + #[cfg(unix)] + notify_parent_pid: Option, + + /// How long to keep retrying `SIGUSR1` to the parent when it fails with `EPERM`. + /// See [`ServerConf::daemon_notify_timeout_seconds`]. + #[cfg(unix)] + daemon_notify_timeout: std::time::Duration, + #[cfg(feature = "sentry")] #[cfg_attr(docsrs, doc(cfg(feature = "sentry")))] /// The Sentry ClientOptions. @@ -96,6 +114,13 @@ impl Bootstrap { upgrade_sock, #[cfg(unix)] listen_fds: Arc::new(Mutex::new(Fds::new())), + #[cfg(unix)] + notify_parent_pid: None, + #[cfg(unix)] + daemon_notify_timeout: conf + .daemon_notify_timeout_seconds + .map(|n| Duration::from_secs(n.get())) + .unwrap_or(DEFAULT_DAEMON_NOTIFY_TIMEOUT), execution_phase_watch: execution_phase_watch.clone(), completed: false, #[cfg(feature = "sentry")] @@ -110,6 +135,13 @@ impl Bootstrap { self.sentry = sentry_config; } + /// Store the parent process PID to notify via `SIGUSR1` after bootstrap completes. + /// Only relevant when [`ServerConf::daemon_wait_for_ready`] is `true`. + #[cfg(unix)] + pub fn set_notify_parent_pid(&mut self, pid: u32) { + self.notify_parent_pid = Some(pid); + } + /// Initialize the Sentry client from the configured [`ClientOptions`] and /// store the resulting guard. /// @@ -161,7 +193,17 @@ impl Bootstrap { std::process::exit(0); } - // load fds + // Notify the parent process that it can exit. It might seem like we should load the file + // descriptors from the old process first, but the purpose of this notification is to + // release the parent so that the process managing it (e.g. systemd) can continue and send + // a quit signal to the old process. That quit signal is required before the old process + // will start trying to send its file descriptors to us — so if we called load_fds first, + // we would be guaranteeing a timeout. + #[cfg(unix)] + if let Some(pid) = self.notify_parent_pid { + notify_parent_ready_for_fds(pid, self.daemon_notify_timeout); + } + #[cfg(unix)] match self.load_fds(self.upgrade) { Ok(_) => { diff --git a/pingora-core/src/server/configuration/mod.rs b/pingora-core/src/server/configuration/mod.rs index 8ab02bf3..1f410892 100644 --- a/pingora-core/src/server/configuration/mod.rs +++ b/pingora-core/src/server/configuration/mod.rs @@ -25,6 +25,7 @@ use pingora_error::{Error, ErrorType::*, OrErr, Result}; use serde::{Deserialize, Serialize}; use std::ffi::OsString; use std::fs; +use std::num::NonZeroU64; // default maximum upstream retries for retry-able proxy errors const DEFAULT_MAX_RETRIES: usize = 16; @@ -125,6 +126,45 @@ pub struct ServerConf { /// /// When not set, the tokio default (10 seconds) is used. pub blocking_threads_ttl_seconds: Option, + /// When `daemon` is `true`, controls whether the parent process of the daemon fork waits for + /// the child to signal readiness before exiting. + /// + /// When `false` (default), the parent exits immediately after the daemon fork, matching the + /// traditional daemonization behavior. Systemd will consider the service started as soon as + /// the parent exits, which may be before the child has finished bootstrapping. + /// + /// When `true`, the parent waits (up to [`Self::daemon_ready_timeout_seconds`]) for the child + /// to send `SIGUSR1` after bootstrap completes. This causes systemd to delay any subsequent + /// steps (such as sending `SIGQUIT` to the old process) until the new instance is fully ready + /// to serve traffic. If the child does not signal in time, the parent exits with a non-zero + /// exit code, causing systemd to abort the reload. + pub daemon_wait_for_ready: bool, + /// Timeout in seconds for the parent process to wait for the child to signal readiness during + /// daemonization when [`Self::daemon_wait_for_ready`] is `true`. + /// + /// If the child does not send `SIGUSR1` within this timeout, the parent exits with a non-zero + /// exit code. + /// + /// Defaults to 600 seconds (10 minutes). + pub daemon_ready_timeout_seconds: Option, + /// How long the child process will keep retrying `SIGUSR1` to the parent when the signal + /// fails with a permission error (`EPERM`) during daemonization. + /// + /// After the daemon fork, the parent always drops its credentials to the configured user and + /// group (see [`Self::user`], [`Self::group`]). Because the privilege drop happens after the + /// fork, there is a small window where the child may attempt to signal the parent before the + /// parent has finished changing its credentials. During this window the kernel will reject the + /// signal with `EPERM` because the child and parent are running as different users. The child + /// retries every 100 ms until this timeout elapses. + /// + /// In practice this window is very small, so the default of 60 seconds is far more than + /// enough to account for it. + /// + /// Only retries on `EPERM`; any other error (e.g. `ESRCH` — parent no longer exists) is + /// treated as fatal and logged without retrying. + /// + /// Defaults to 60 seconds. + pub daemon_notify_timeout_seconds: Option, } impl Default for ServerConf { @@ -155,6 +195,9 @@ impl Default for ServerConf { upgrade_sock_connect_accept_max_retries: None, max_blocking_threads: None, blocking_threads_ttl_seconds: None, + daemon_ready_timeout_seconds: None, + daemon_wait_for_ready: false, + daemon_notify_timeout_seconds: None, } } } @@ -326,6 +369,9 @@ mod tests { upgrade_sock_connect_accept_max_retries: None, max_blocking_threads: None, blocking_threads_ttl_seconds: None, + daemon_ready_timeout_seconds: None, + daemon_wait_for_ready: false, + daemon_notify_timeout_seconds: None, }; // cargo test -- --nocapture not_a_test_i_cannot_write_yaml_by_hand println!("{}", conf.to_yaml()); diff --git a/pingora-core/src/server/daemon.rs b/pingora-core/src/server/daemon.rs index 7381fc93..b6c95cb0 100644 --- a/pingora-core/src/server/daemon.rs +++ b/pingora-core/src/server/daemon.rs @@ -12,18 +12,71 @@ // See the License for the specific language governing permissions and // limitations under the License. -use daemonize::{Daemonize, Stdio}; -use log::{debug, error}; +use daemonize::{Daemonize, Outcome, Stdio}; +use log::{debug, error, info}; +use pingora_error::{Error, ErrorType, OrErr, Result}; use std::ffi::CString; use std::fs::{self, OpenOptions}; use std::os::unix::prelude::OpenOptionsExt; use std::path::Path; +use std::process; +use std::thread; +use std::time::{Duration, Instant}; use crate::server::configuration::ServerConf; +/// Error returned by [`send_signal`]. +#[derive(Debug)] +pub(crate) enum SignalError { + /// The caller does not have permission to send the signal to the target process (`EPERM`). + PermissionDenied, + /// Any other error from `kill(2)`. Contains the raw `errno` value. + OtherSignalError(i32), +} + +impl std::fmt::Display for SignalError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SignalError::PermissionDenied => write!(f, "permission denied (EPERM)"), + SignalError::OtherSignalError(errno) => { + write!(f, "kill failed with errno {errno}") + } + } + } +} + +/// Send `signal` to the process identified by `pid`. +/// +/// Returns `Ok(())` on success. On failure, maps `errno` to [`SignalError`]: +/// - `EPERM` → [`SignalError::PermissionDenied`] +/// - anything else → [`SignalError::OtherSignalError`] containing the raw errno value. +fn send_signal(pid: libc::pid_t, signal: libc::c_int) -> Result<(), SignalError> { + // SAFETY: `kill(2)` is safe to call with any pid/signal combination — invalid values + // simply return an error via errno rather than causing undefined behavior. + let ret = unsafe { libc::kill(pid, signal) }; + if ret == 0 { + return Ok(()); + } + let errno = std::io::Error::last_os_error().raw_os_error().unwrap_or(-1); + if errno == libc::EPERM { + Err(SignalError::PermissionDenied) + } else { + Err(SignalError::OtherSignalError(errno)) + } +} + // Utilities to daemonize a pingora server, i.e. run the process in the background, possibly // under a different running user and/or group. +/// Default timeout for the parent to wait for the daemon child to signal readiness. +const DEFAULT_DAEMON_READY_TIMEOUT: Duration = Duration::from_secs(600); + +/// How long to sleep between `SIGUSR1` send attempts when `EPERM` is returned. +const NOTIFY_RETRY_INTERVAL: Duration = Duration::from_millis(100); + +/// How long to sleep between pid-file liveness checks in the async wait loop. +const LIVENESS_CHECK_INTERVAL: Duration = Duration::from_millis(100); + // XXX: this operation should have been done when the old service is exiting. // Now the new pid file just kick the old one out of the way fn move_old_pid(path: &str) { @@ -45,7 +98,15 @@ fn move_old_pid(path: &str) { } } +/// # Safety +/// +/// `name` must be a valid, null-terminated C string. The returned `gid_t` is read from the +/// `passwd` struct returned by `getpwnam(3)`, which points to a static buffer that may be +/// overwritten by subsequent calls to `getpwnam` or `getpwuid`. The caller must not hold the +/// pointer across such calls. unsafe fn gid_for_username(name: &CString) -> Option { + // SAFETY: `name` is a valid CString; `getpwnam` returns a pointer to a static buffer + // or null. We read `pw_gid` immediately and do not retain the pointer. let passwd = libc::getpwnam(name.as_ptr() as *const libc::c_char); if !passwd.is_null() { return Some((*passwd).pw_gid); @@ -53,9 +114,277 @@ unsafe fn gid_for_username(name: &CString) -> Option { None } +/// Drop the parent process's UID to the user specified in [`ServerConf::user`]. +/// +/// The kernel only permits a process to send a signal to another if they share the same UID (or +/// the sender is root). Since the daemon child sends `SIGUSR1` to the parent to signal readiness, +/// the parent must be running as the same UID as the child by the time that signal arrives — +/// otherwise the kernel will reject it with `EPERM`. +/// +/// This function is called in the `Outcome::Parent` path immediately after `execute()` returns, +/// before the parent enters its readiness wait loop, so the parent's UID matches the child's as +/// quickly as possible after the fork. +/// +/// Only the UID is changed; the GID is left as-is. Signal permission checks are based on UID, +/// so changing the GID is not necessary for this purpose. +/// +/// Logs an error and continues if the user cannot be resolved or `setuid` fails — the parent +/// is short-lived and about to exit, so a failed privilege drop is non-fatal. The child's +/// `EPERM` retry window (see [`ServerConf::daemon_notify_timeout_seconds`]) exists precisely to +/// cover the small gap between the fork and the parent completing this UID change. +fn drop_privileges_in_parent(conf: &ServerConf) -> Result<()> { + let Some(user) = conf.user.as_ref() else { + return Ok(()); + }; + + let user_cstr = CString::new(user.as_str()).or_err_with(ErrorType::Custom("Daemon"), || { + format!("drop_privileges_in_parent: user '{user}' invalid") + })?; + + // SAFETY: `user_cstr` is a valid CString. `getpwnam` returns a pointer to a static + // buffer or null. We read `pw_uid` immediately and do not retain the pointer. + let passwd = unsafe { libc::getpwnam(user_cstr.as_ptr() as *const libc::c_char) }; + if passwd.is_null() { + return Error::e_explain( + ErrorType::Custom("Daemon"), + format!("drop_privileges_in_parent: user '{user}' not found"), + ); + } + + // SAFETY: `passwd` was checked for null above. We dereference it once to read `pw_uid`. + let uid = unsafe { (*passwd).pw_uid }; + // SAFETY: `setuid(2)` is safe to call with any uid — invalid values return an error. + let ret = unsafe { libc::setuid(uid) }; + if ret == 0 { + Ok(()) + } else { + Error::e_explain( + ErrorType::Custom("Daemon"), + format!( + "drop_privileges_in_parent: setuid({uid}) failed: {}", + std::io::Error::last_os_error() + ), + ) + } +} + +/// Outcome of calling [`daemonize`]. +/// +/// When [`ServerConf::daemon_wait_for_ready`] is `true`, the child process must call +/// [`notify_parent_ready_for_fds`] after bootstrap completes to unblock the parent's wait loop. +pub struct DaemonizeResult { + /// The PID of the original parent process to notify via `SIGUSR1` after bootstrap completes. + /// + /// `Some` when [`ServerConf::daemon_wait_for_ready`] is `true`, `None` otherwise. + pub notify_parent_pid: Option, +} + /// Start a server instance as a daemon. -#[cfg(unix)] -pub fn daemonize(conf: &ServerConf) { +/// +/// Both code paths use [`daemonize::Daemonize::execute()`] rather than calling `fork()` directly. +/// `execute()` returns an [`Outcome`] to the caller in each process rather than having the parent +/// exit inside the crate, which gives us the opportunity to run additional logic in the parent +/// before it exits. +/// +/// When [`ServerConf::daemon_wait_for_ready`] is `false` (the default), the parent exits +/// immediately — matching the behavior of `start()`. +/// +/// When `daemon_wait_for_ready` is `true`, the parent registers a `SIGUSR1` handler before +/// forking, then waits (in a sleep loop polling the pid file and the signal flag) for up to +/// [`ServerConf::daemon_ready_timeout_seconds`] (default 600 s) for the grandchild to send +/// `SIGUSR1`. On success the parent exits with code 0. On timeout, or if the daemon process +/// exits before signaling, the parent exits with code 1, causing systemd to abort the reload. +/// +/// Returns a [`DaemonizeResult`] that is only meaningful to the child process. The parent always +/// exits before returning. +pub fn daemonize(conf: &ServerConf) -> DaemonizeResult { + // Capture the parent PID before forking so it can be passed to the grandchild. The + // grandchild sends SIGUSR1 to this PID after bootstrap completes. + let parent_pid = if conf.daemon_wait_for_ready { + Some(process::id()) + } else { + None + }; + + move_old_pid(&conf.pid_file); + + match build_daemonize(conf).execute() { + Outcome::Parent(result) => { + result.unwrap_or_else(|e| panic!("Daemonize failed: {e}")); + } + Outcome::Child(result) => { + result.unwrap_or_else(|e| panic!("Daemonize child setup failed: {e}")); + return DaemonizeResult { + notify_parent_pid: parent_pid, + }; + } + } + + if conf.daemon_wait_for_ready { + // Drop root privileges before waiting so the parent does not linger as root. + if let Err(e) = drop_privileges_in_parent(conf) { + error!("drop_privileges_in_parent failed: {e}"); + + // Exiting the parent process should be fine because if downgrading + // the user's privileges fails here, it will fail in the child and + // the child will exit too + process::exit(1); + } + + let timeout = conf + .daemon_ready_timeout_seconds + .map(|n| Duration::from_secs(n.get())) + .unwrap_or(DEFAULT_DAEMON_READY_TIMEOUT); + + info!( + "Waiting up to {:?} for daemon to signal readiness via SIGUSR1", + timeout + ); + + wait_for_ready_or_exit(&conf.pid_file, timeout); + } + + process::exit(0); +} + +/// Build a single-threaded tokio runtime for the parent's signal wait loop. +/// +/// The parent process is short-lived and only needs to wait for a signal and check the pid file. +/// A current-thread runtime avoids spawning worker threads in a process that is about to exit. +fn build_parent_runtime() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to build tokio runtime for parent signal wait") +} + +/// Wait for the daemon grandchild to send `SIGUSR1`, up to `timeout`. +/// +/// Uses a local tokio runtime with [`tokio::signal::unix`] to listen for `SIGUSR1` instead of +/// raw signal handlers and polling loops. The daemon's PID is checked periodically via the pid +/// file — if the process exits before signaling, the parent aborts. +/// +/// Exits the process directly: +/// - exit code 0 if `SIGUSR1` is received (daemon is ready). +/// - exit code 1 if `timeout` elapses (daemon took too long). +/// - exit code 1 if the pid file exists and the process is no longer running. +fn wait_for_ready_or_exit(pid_file: &str, timeout: Duration) { + let rt = build_parent_runtime(); + let pid_file = pid_file.to_owned(); + + rt.block_on(async move { + use tokio::signal::unix::{signal, SignalKind}; + use tokio::time::{interval, timeout as tokio_timeout}; + + let mut sigusr1 = + signal(SignalKind::user_defined1()).expect("failed to register SIGUSR1 listener"); + + let mut liveness_check = interval(LIVENESS_CHECK_INTERVAL); + let mut daemon_pid: Option = None; + + let result = tokio_timeout(timeout, async { + loop { + tokio::select! { + _ = sigusr1.recv() => { + info!("Daemon signaled readiness, parent exiting"); + return; + } + _ = liveness_check.tick() => { + if daemon_pid.is_none() { + daemon_pid = try_read_pid_file(&pid_file); + } + if let Some(pid) = daemon_pid { + if !process_is_running(pid) { + error!( + "Daemon process (pid {pid}) is no longer running \ + before signaling readiness, aborting" + ); + process::exit(1); + } + } + } + } + } + }) + .await; + + if result.is_err() { + error!("Daemon did not signal readiness within {timeout:?}, aborting"); + process::exit(1); + } + }); +} + +/// Notify the parent process that the daemon is ready to serve traffic by sending `SIGUSR1`. +/// +/// Should be called by the daemon process after bootstrap is complete when +/// [`ServerConf::daemon_wait_for_ready`] is `true`. `parent_pid` is the PID of the original +/// process captured before the fork and stored in [`DaemonizeResult::notify_parent_pid`]. +/// +/// `SIGUSR1` sets an atomic flag that the parent's wait loop checks, causing it to exit with +/// code 0 and allowing systemd to proceed with the next step of the service reload. +/// +/// If `kill(2)` returns `EPERM` — which can happen transiently when the child's UID has just +/// been changed by `setuid` and the kernel hasn't yet updated the credential check — the +/// function sleeps for [`NOTIFY_RETRY_INTERVAL`] (100 ms) and retries until `notify_timeout` +/// elapses, at which point it logs an error and returns. Any other error (e.g. `ESRCH`, +/// meaning the parent no longer exists) is logged and the function returns immediately without +/// retrying. +pub fn notify_parent_ready_for_fds(parent_pid: u32, notify_timeout: Duration) { + let parent_pid = parent_pid as libc::pid_t; + info!( + "Sending SIGUSR1 to parent process (pid {}) to signal daemon readiness", + parent_pid + ); + + let start = Instant::now(); + + while start.elapsed() < notify_timeout { + match send_signal(parent_pid, libc::SIGUSR1) { + Ok(()) => return, + Err(SignalError::PermissionDenied) => { + debug!( + "Permission denied sending SIGUSR1 to parent (pid {}), retrying in {:?}", + parent_pid, NOTIFY_RETRY_INTERVAL + ); + thread::sleep(NOTIFY_RETRY_INTERVAL); + } + Err(SignalError::OtherSignalError(errno)) => { + error!( + "Failed to send SIGUSR1 to parent (pid {}): errno {errno}", + parent_pid + ); + return; + } + } + } + + error!( + "Permission denied sending SIGUSR1 to parent (pid {}), giving up after {:?}", + parent_pid, notify_timeout + ); +} + +/// Try to read a PID from `pid_file`. Returns `None` if the file does not exist or cannot be +/// parsed. +fn try_read_pid_file(pid_file: &str) -> Option { + fs::read_to_string(pid_file) + .ok() + .and_then(|c| c.trim().parse().ok()) +} + +/// Returns `true` if a process with `pid` is currently running. +fn process_is_running(pid: libc::pid_t) -> bool { + // Signal 0 does not send a signal; it just checks whether the process exists and whether + // we have permission to signal it. EPERM (no permission) is not possible here because + // drop_privileges_in_parent guarantees the parent has already dropped to the same user as + // the daemon child before this function is called. + send_signal(pid, 0).is_ok() +} + +/// Build a [`Daemonize`] instance configured from `conf`, without calling `start()` or +/// `execute()`. The caller is responsible for driving execution. +fn build_daemonize(conf: &ServerConf) -> Daemonize<()> { // TODO: customize working dir let daemonize = Daemonize::new() @@ -82,6 +411,7 @@ pub fn daemonize(conf: &ServerConf) { Some(user) => { let user_cstr = CString::new(user.as_str()).unwrap(); + // SAFETY: `user_cstr` is a valid CString. See `gid_for_username` safety docs. #[cfg(target_os = "macos")] let group_id = unsafe { gid_for_username(&user_cstr).map(|gid| gid as i32) }; #[cfg(target_os = "freebsd")] @@ -92,7 +422,8 @@ pub fn daemonize(conf: &ServerConf) { daemonize .privileged_action(move || { if let Some(gid) = group_id { - // Set the supplemental group privileges for the child process. + // SAFETY: `user_cstr` is a valid CString captured by the closure. + // `initgroups(3)` is safe to call with a valid username and gid. unsafe { libc::initgroups(user_cstr.as_ptr() as *const libc::c_char, gid); } @@ -104,12 +435,8 @@ pub fn daemonize(conf: &ServerConf) { None => daemonize, }; - let daemonize = match conf.group.as_ref() { + match conf.group.as_ref() { Some(group) => daemonize.group(group.as_str()), None => daemonize, - }; - - move_old_pid(&conf.pid_file); - - daemonize.start().unwrap(); // hard crash when fail + } } diff --git a/pingora-core/src/server/mod.rs b/pingora-core/src/server/mod.rs index ffedf665..0d3a105e 100644 --- a/pingora-core/src/server/mod.rs +++ b/pingora-core/src/server/mod.rs @@ -624,8 +624,13 @@ impl Server { if conf.daemon { info!("Daemonizing the server"); fast_timeout::pause_for_fork(); - daemonize(&self.configuration); + let daemonize_result = daemonize(&self.configuration); fast_timeout::unpause(); + // If daemon_wait_for_ready is enabled, pass the parent PID to bootstrap so it + // can send SIGUSR1 to the parent after bootstrap completes. + if let Some(pid) = daemonize_result.notify_parent_pid { + self.bootstrap.lock().set_notify_parent_pid(pid); + } } #[cfg(windows)] diff --git a/pingora-core/tests/bootstrap_as_a_service.rs b/pingora-core/tests/bootstrap_as_a_service.rs new file mode 100644 index 00000000..fa88ff20 --- /dev/null +++ b/pingora-core/tests/bootstrap_as_a_service.rs @@ -0,0 +1,136 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Integration tests for `bootstrap_as_a_service`. +//! +//! Verifies that when `bootstrap_as_a_service()` dependencies are declared, the +//! `BootstrapComplete` execution phase is not reached until all dependency services have +//! finished their initialization work. + +use async_trait::async_trait; +use pingora_core::server::ShutdownWatch; +use pingora_core::server::{configuration::ServerConf, ExecutionPhase, RunArgs, Server}; +use pingora_core::services::background::{background_service, BackgroundService}; +use pingora_core::services::ServiceReadyNotifier; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +/// A background service that sets a flag when it completes, after an optional delay. +/// +/// Signals readiness only after completing its initialization work so that dependent +/// services (like `BootstrapService`) cannot start until this service is truly done. +struct TrackableService { + delay: Duration, + completed: Arc, +} + +#[async_trait] +impl BackgroundService for TrackableService { + async fn start_with_ready_notifier( + &self, + _shutdown: ShutdownWatch, + ready_notifier: ServiceReadyNotifier, + ) { + if !self.delay.is_zero() { + tokio::time::sleep(self.delay).await; + } + self.completed.store(true, Ordering::SeqCst); + // Signal readiness only after work is done — this is what the dependency + // mechanism waits on before allowing BootstrapService to proceed. + ready_notifier.notify_ready(); + } +} + +/// Verifies that `bootstrap_as_a_service` does not reach `BootstrapComplete` until all +/// declared dependency services have finished their initialization work. +#[test] +fn test_bootstrap_waits_for_dependencies() { + let conf = ServerConf { + grace_period_seconds: Some(1), + graceful_shutdown_timeout_seconds: Some(1), + ..Default::default() + }; + + let mut server = Server::new_with_opt_and_conf(None, conf); + let mut phase = server.watch_execution_phase(); + + // Two dependency services with delays. The second (150 ms) sets the pace. + let dep1_done = Arc::new(AtomicBool::new(false)); + let dep2_done = Arc::new(AtomicBool::new(false)); + + let dep1_handle = server.add_service(background_service( + "dep1", + TrackableService { + delay: Duration::from_millis(50), + completed: dep1_done.clone(), + }, + )); + let dep2_handle = server.add_service(background_service( + "dep2", + TrackableService { + delay: Duration::from_millis(150), + completed: dep2_done.clone(), + }, + )); + + // BootstrapService must not reach BootstrapComplete until dep1 and dep2 are done. + let bootstrap_handle = server.bootstrap_as_a_service(); + bootstrap_handle.add_dependencies([&dep1_handle, &dep2_handle]); + + // When using bootstrap_as_a_service, do NOT call server.bootstrap() separately — + // the BootstrapService runs as a background service during run(), and emits + // BootstrapComplete only after all its declared dependencies are ready. + let _join = std::thread::spawn(move || { + server.run(RunArgs::default()); + }); + + let mut received_bootstrap = false; + let mut received_bootstrap_complete = false; + + // Collect phases until BootstrapComplete is seen. Running may arrive + // before or after Bootstrap/BootstrapComplete since main_loop starts + // concurrently with the service runtimes. + loop { + match phase.blocking_recv() { + Ok(ExecutionPhase::Bootstrap) => { + received_bootstrap = true; + } + Ok(ExecutionPhase::BootstrapComplete) => { + // Both dependencies must have set their flags before bootstrap completes. + assert!( + dep1_done.load(Ordering::SeqCst), + "dep1 should be done before BootstrapComplete" + ); + assert!( + dep2_done.load(Ordering::SeqCst), + "dep2 should be done before BootstrapComplete" + ); + received_bootstrap_complete = true; + break; + } + Ok(_) => {} + Err(_) => break, + } + } + + assert!(received_bootstrap, "should have seen Bootstrap phase"); + assert!( + received_bootstrap_complete, + "should have seen BootstrapComplete phase" + ); + + // Shut down cleanly. + std::process::exit(0); +} diff --git a/pingora/examples/graceful_upgrade.rs b/pingora/examples/graceful_upgrade.rs new file mode 100644 index 00000000..5a64ff7f --- /dev/null +++ b/pingora/examples/graceful_upgrade.rs @@ -0,0 +1,186 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! # Graceful Upgrade Example +//! +//! Demonstrates the `daemon_wait_for_ready` feature, which coordinates graceful process upgrades +//! by ensuring the new process is fully bootstrapped before the old one begins shutting down. +//! +//! ## Background +//! +//! In a standard daemonized pingora service, the parent process exits immediately after the +//! daemon fork. During a graceful upgrade, the process manager sends SIGQUIT to the old process +//! as soon as the new process's parent exits — potentially before the new process has finished +//! initializing its backends, consistent hash rings, or other state. This can cause a brief +//! window of 502s. +//! +//! With `daemon_wait_for_ready = true`, the parent instead waits for the daemon to send SIGUSR1 +//! before exiting. The process manager only proceeds to stop the old process once the new one +//! signals that it is ready to serve traffic. +//! +//! ## Service startup order +//! +//! This example sets up the following dependency chain: +//! +//! ```text +//! BackendDiscoveryService HashRingService +//! \ / +//! \ / +//! BootstrapService (socket transfer + SIGUSR1 to parent) +//! ``` +//! +//! The bootstrap service — which handles transferring listening sockets from the old process and +//! sending SIGUSR1 to the parent to signal readiness — only runs after both slow initialization +//! services have completed. This ensures the parent never exits until the new process is truly +//! ready to serve traffic. +//! +//! ## Usage +//! +//! ```bash +//! # Run interactively (no daemonization) +//! cargo run --example graceful_upgrade -p pingora +//! +//! # Run as a daemon +//! cargo run --example graceful_upgrade -p pingora -- -d +//! +//! # Graceful upgrade of a running daemon instance +//! cargo run --example graceful_upgrade -p pingora -- -d -u +//! ``` + +use async_trait::async_trait; +use bytes::Bytes; +use clap::Parser; +use http::{Response, StatusCode}; +use log::info; +use std::num::NonZeroU64; +use std::time::Duration; +use tokio::time::sleep; + +use pingora::apps::http_app::ServeHttp; +use pingora::prelude::Opt; +use pingora::protocols::http::ServerSession; +use pingora::server::configuration::ServerConf; +use pingora::server::{Server, ShutdownWatch}; +use pingora::services::background::{background_service, BackgroundService}; +use pingora::services::listening::Service as ListeningService; + +/// Simulates slow backend discovery — e.g. resolving upstream endpoints from a service registry. +pub struct BackendDiscoveryService; + +#[async_trait] +impl BackgroundService for BackendDiscoveryService { + async fn start(&self, _shutdown: ShutdownWatch) { + info!("BackendDiscoveryService: discovering backends..."); + sleep(Duration::from_secs(2)).await; + info!("BackendDiscoveryService: backends ready"); + } +} + +/// Simulates slow consistent hash ring construction. Runs in parallel with +/// `BackendDiscoveryService`; bootstrap waits for both to complete. +pub struct HashRingService; + +#[async_trait] +impl BackgroundService for HashRingService { + async fn start(&self, _shutdown: ShutdownWatch) { + info!("HashRingService: building consistent hash ring..."); + sleep(Duration::from_secs(3)).await; + info!("HashRingService: hash ring ready"); + } +} + +/// A minimal HTTP service that responds to every request with 200 OK. +/// +/// Accepts an optional `sleep` query parameter specifying how many seconds to wait before +/// responding (e.g. `GET /?sleep=20`). This makes in-flight requests easy to observe during a +/// graceful upgrade: a request with a long sleep that arrives just before the upgrade begins will +/// still be running when the new process starts up, demonstrating that the old process keeps +/// serving until all connections are drained. +pub struct HelloApp; + +#[async_trait] +impl ServeHttp for HelloApp { + async fn response(&self, http_stream: &mut ServerSession) -> Response> { + let delay_secs = http_stream + .req_header() + .uri + .query() + .and_then(|q| { + q.split('&').find_map(|pair| { + let (key, val) = pair.split_once('=')?; + if key == "sleep" { + val.parse::().ok() + } else { + None + } + }) + }) + .unwrap_or(0); + + if delay_secs > 0 { + sleep(Duration::from_secs(delay_secs)).await; + } + + let body = Bytes::from("hello from graceful_upgrade example\n"); + Response::builder() + .status(StatusCode::OK) + .header(http::header::CONTENT_TYPE, "text/plain") + .header(http::header::CONTENT_LENGTH, body.len()) + .body(body.to_vec()) + .unwrap() + } +} + +fn main() { + env_logger::init(); + + let opt = Some(Opt::parse()); + + // Build a ServerConf with daemon_wait_for_ready enabled. + // + // When the server is started with -d (daemon mode), the parent process waits for SIGUSR1 + // before exiting. The daemon sends SIGUSR1 only after the bootstrap service completes — + // which in this example means after both slow services have signaled readiness. + let conf = ServerConf { + daemon: true, + daemon_wait_for_ready: true, + daemon_ready_timeout_seconds: NonZeroU64::new(60), + ..ServerConf::default() + }; + + let mut server = Server::new_with_opt_and_conf(opt, conf); + + // Add the slow initialization services and retain their handles so bootstrap can depend + // on them. Both run in parallel; the slowest (HashRingService at 3s) sets the pace. + let backend_handle = server.add_service(background_service( + "backend_discovery", + BackendDiscoveryService, + )); + let hash_ring_handle = server.add_service(background_service("hash_ring", HashRingService)); + + // bootstrap_as_a_service() registers the bootstrap service (socket transfer from the old + // process + SIGUSR1 to the parent) and returns its ServiceHandle. Declaring the slow + // services as dependencies ensures bootstrap only runs once both are ready. + let bootstrap_handle = server.bootstrap_as_a_service(); + bootstrap_handle.add_dependencies([&backend_handle, &hash_ring_handle]); + + let mut http_service = ListeningService::new("hello_http".to_string(), HelloApp); + http_service.add_tcp("0.0.0.0:8000"); + + server + .add_service(http_service) + .add_dependency(backend_handle); + + server.run_forever(); +} From d7728cac9afa137b2cd75f645ba2685b9b912a89 Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Fri, 27 Feb 2026 20:11:19 -0800 Subject: [PATCH 14/52] Add cancel-safe body and header writer primitives Add BodyWriter task API (send_body_task, write_current_body_task, send_finish_task, write_current_finish_task) and HeaderWriter for cancel-safe writes that can be used in tokio::select! loops. --- .bleep | 2 +- pingora-core/src/protocols/http/v1/body.rs | 1541 +++++++++++++++++- pingora-core/src/protocols/http/v1/header.rs | 449 +++++ pingora-core/src/protocols/http/v1/mod.rs | 1 + pingora-core/src/protocols/l4/stream.rs | 65 +- 5 files changed, 2022 insertions(+), 36 deletions(-) create mode 100644 pingora-core/src/protocols/http/v1/header.rs diff --git a/.bleep b/.bleep index a5fdccd2..8ca85d6c 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -3f438f804b9e954f7b882815fc9e70ebe9d572a0 \ No newline at end of file +ade24f55fc0b8c3be1b0a22da73cfe94058f811c \ No newline at end of file diff --git a/pingora-core/src/protocols/http/v1/body.rs b/pingora-core/src/protocols/http/v1/body.rs index 72899257..fbed3b11 100644 --- a/pingora-core/src/protocols/http/v1/body.rs +++ b/pingora-core/src/protocols/http/v1/body.rs @@ -20,9 +20,14 @@ use pingora_error::{ OrErr, Result, }; use std::fmt::Debug; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::protocols::l4::stream::AsyncWriteVec; +use crate::protocols::l4::stream::{ + async_write_vec::{poll_write_all_buf, poll_write_vec_all_buf}, + AsyncWriteVec, +}; use crate::utils::BufRef; // TODO: make this dynamically adjusted @@ -905,14 +910,193 @@ pub enum BodyMode { type BM = BodyMode; +// ============================================================================ +// Cancel-safe body writing types +// ============================================================================ + +impl BodyMode { + /// Extract `(total, written)` from `ContentLength`, panicking on mismatch. + fn expect_content_length(&self) -> (usize, usize) { + match self { + BodyMode::ContentLength(total, written) => (*total, *written), + _ => panic!("wrong body mode: expected ContentLength, got {:?}", self), + } + } + + /// Extract `written` from `ChunkedEncoding`, panicking on mismatch. + fn expect_chunked(&self) -> usize { + match self { + BodyMode::ChunkedEncoding(written) => *written, + _ => panic!("wrong body mode: expected ChunkedEncoding, got {:?}", self), + } + } + + /// Extract `written` from `UntilClose`, panicking on mismatch. + fn expect_until_close(&self) -> usize { + match self { + BodyMode::UntilClose(written) => *written, + _ => panic!("wrong body mode: expected UntilClose, got {:?}", self), + } + } +} + +/// Type alias for the chunked encoding buffer chain +type ChunkedBuf = bytes::buf::Chain, &'static [u8]>; + +#[allow(dead_code)] +enum WriteBuf { + /// Simple bytes buffer + Simple(Bytes), + /// Chained buffer for chunked encoding or other complex writes + Chained(C), +} + +// Implement Buf for WriteBuf to delegate to the inner buffer +impl Buf for WriteBuf { + fn remaining(&self) -> usize { + match self { + WriteBuf::Simple(b) => b.remaining(), + WriteBuf::Chained(c) => c.remaining(), + } + } + + fn chunk(&self) -> &[u8] { + match self { + WriteBuf::Simple(b) => b.chunk(), + WriteBuf::Chained(c) => c.chunk(), + } + } + + fn advance(&mut self, cnt: usize) { + match self { + WriteBuf::Simple(b) => b.advance(cnt), + WriteBuf::Chained(c) => c.advance(cnt), + } + } +} + +#[allow(dead_code)] +enum WriteState { + /// No write in progress + Idle, + /// Writing data (original size, bytes remaining to write) + Writing(usize, WriteBuf), + /// Flushing after write (original size to return) + Flushing(usize), + /// Write complete (bytes written in this task) + Done(usize), + /// Write timed out - cannot be reused + TimedOut, +} + +// Custom Debug implementation since we can't derive it with futures +impl std::fmt::Debug for WriteState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + WriteState::Idle => write!(f, "Idle"), + WriteState::Writing(size, _buf) => { + write!(f, "Writing(size: {})", size) + } + WriteState::Flushing(size) => write!(f, "Flushing(size: {})", size), + WriteState::Done(size) => write!(f, "Done(size: {})", size), + WriteState::TimedOut => write!(f, "TimedOut"), + } + } +} + +#[allow(dead_code)] +enum FinishWriteState { + /// No finish task queued + NotStarted, + /// Finish queued but not started yet + Idle, + /// Writing last chunk marker (for chunked encoding) + WritingLastChunk(WriteBuf), + /// Flushing after writing last chunk + Flushing, + /// Finish complete + Done, +} + +// Custom Debug implementation since WriteBuf doesn't implement Debug +impl std::fmt::Debug for FinishWriteState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FinishWriteState::NotStarted => write!(f, "NotStarted"), + FinishWriteState::Idle => write!(f, "Idle"), + FinishWriteState::WritingLastChunk(_) => write!(f, "WritingLastChunk"), + FinishWriteState::Flushing => write!(f, "Flushing"), + FinishWriteState::Done => write!(f, "Done"), + } + } +} + +/// Internal state for the cancel-safe body write state machine. +/// +/// Tracks the pending body bytes, write progress +/// (idle → writing → flushing → done), and an optional timeout. +struct SendBodyState { + /// Application bytes queued to be written + pending_bytes: Option, + /// Current write state for cancel-safe operations + write_state: WriteState, + /// Timeout duration for this write task + timeout_duration: Option, + /// Timeout future (only created if write returns Pending) + timeout_fut: Option + Send + Sync>>>, +} + +impl std::fmt::Debug for SendBodyState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SendBodyState") + .field("pending_bytes", &self.pending_bytes) + .field("write_state", &self.write_state) + .field("timeout_duration", &self.timeout_duration) + .field( + "timeout_fut", + &self.timeout_fut.as_ref().map(|_| "Some(Future)"), + ) + .finish() + } +} + +impl SendBodyState { + fn new() -> Self { + SendBodyState { + pending_bytes: None, + write_state: WriteState::Idle, + timeout_duration: None, + timeout_fut: None, + } + } +} + +/// Tracks how response body bytes are framed and written to the wire. +/// +/// Supports both a legacy async API (`write_body` / `finish`) and a cancel-safe +/// task API that can be driven inside a `tokio::select!` loop without losing +/// write progress. pub struct BodyWriter { pub body_mode: BodyMode, + // Boxed to reduce inline size. Only used by the cancel-safe proxy task API. + #[allow(dead_code)] + send_body_state: Box, + #[allow(dead_code)] + send_finish_state: FinishWriteState, +} + +impl Default for BodyWriter { + fn default() -> Self { + Self::new() + } } impl BodyWriter { pub fn new() -> Self { BodyWriter { body_mode: BM::ToSelect, + send_body_state: Box::new(SendBodyState::new()), + send_finish_state: FinishWriteState::NotStarted, } } @@ -1109,6 +1293,547 @@ impl BodyWriter { _ => panic!("wrong body mode: {:?}", self.body_mode), } } + + // ======================================================================== + // Cancel-safe body task API + // ======================================================================== + + #[cfg(test)] + pub fn has_pending_body_task(&self) -> bool { + self.send_body_state.pending_bytes.is_some() + || !matches!( + self.send_body_state.write_state, + WriteState::Idle | WriteState::Done(_) | WriteState::TimedOut + ) + } + + /// Queue application bytes as a body write task with an optional timeout. + /// This is a non-async function that just saves the bytes. + /// Call `write_current_body_task()` to actually perform the write. + /// + /// The timeout, if provided, will be enforced internally across all + /// write attempts, even if the write is cancelled and resumed via `tokio::select!`. + #[allow(dead_code)] + pub fn send_body_task(&mut self, bytes: Bytes, timeout: Option) { + assert!( + matches!( + self.send_body_state.write_state, + WriteState::Idle | WriteState::Done(_) + ), + "send_body_task called while previous task is still in progress: {:?}", + self.send_body_state.write_state + ); + self.send_body_state.pending_bytes = Some(bytes); + self.send_body_state.write_state = WriteState::Idle; + self.send_body_state.timeout_duration = timeout; + self.send_body_state.timeout_fut = None; + } + + /// Writes the current queued body task to the stream. + /// + /// ## Cancel-safety + /// + /// This function can be safely used in a `tokio::select!` loop. + /// Returns `Ok(Some(bytes_written))` when complete, `Ok(None)` if no bytes to write. + #[allow(dead_code)] + pub async fn write_current_body_task(&mut self, stream: &mut S) -> Result> + where + S: AsyncWrite + Unpin + Send, + { + // Use poll_fn to wrap our poll-based implementation + std::future::poll_fn(|cx| self.poll_write_current_body_task(cx, Pin::new(stream))).await + } + + /// Poll-based implementation for writing body tasks. + /// This is the core implementation that maintains state across cancellations. + fn poll_write_current_body_task( + &mut self, + cx: &mut Context<'_>, + stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Check if already timed out - don't allow reuse + if matches!(self.send_body_state.write_state, WriteState::TimedOut) { + return Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )); + } + + // Lazy timeout optimization: Poll write first, create timeout only if needed. + // + // This follows the pattern from `pingora_timeout::Timeout` to avoid allocating + // and registering timeout futures when writes complete immediately (the common case). + // + // Fast path: Write completes → return immediately, no timeout future created + // Slow path: Write blocks → lazily create timeout future and poll both + + // First, try the write operation + // Dispatch to the appropriate body mode handler + let result = match self.body_mode { + BM::Complete(_) => Poll::Ready(Ok(None)), + BM::ContentLength(_, _) => self.poll_write_content_length_body_task(cx, stream), + BM::ChunkedEncoding(_) => self.poll_write_chunked_body_task(cx, stream), + BM::UntilClose(_) => self.poll_write_until_close_body_task(cx, stream), + BM::ToSelect => Poll::Ready(Ok(None)), + }; + + // If write completed immediately, return without ever creating/polling timeout + if result.is_ready() { + return result; + } + + // Write returned Pending - lazily create and check timeout if duration is set + if let Some(duration) = self.send_body_state.timeout_duration { + let timeout = self.send_body_state.timeout_fut.get_or_insert_with(|| { + Box::pin(pingora_timeout::sleep(duration)) + as std::pin::Pin + Send + Sync>> + }); + + if timeout.as_mut().poll(cx).is_ready() { + // Timeout fired! Mark state as timed out and clear the timeout future + self.send_body_state.write_state = WriteState::TimedOut; + self.send_body_state.timeout_fut = None; + return Poll::Ready(Error::e_explain( + WriteTimedout, + "writing body task timed out", + )); + } + } + + // Both write and timeout are pending + Poll::Pending + } + + // ======================================================================== + // Cancel-safe finish task API + // ======================================================================== + + #[cfg(test)] + pub fn has_pending_finish_task(&self) -> bool { + !matches!( + self.send_finish_state, + FinishWriteState::NotStarted | FinishWriteState::Done + ) + } + + /// Queue a finish operation as a task. + /// This is a non-async function that just marks the finish as pending. + /// Call `write_current_finish_task()` to actually perform the finish. + /// + /// This API is stateful and cancel-safe - use it when you need to finish + /// the body in a `tokio::select!` loop or other cancellable context. + #[allow(dead_code)] + pub fn send_finish_task(&mut self) { + self.send_finish_state = FinishWriteState::Idle; + } + + /// Async function that performs the current queued finish task on the stream. + /// This function is cancel-safe and can be called in a `tokio::select!` loop. + /// Returns `Ok(Some(bytes_written))` when complete, `Ok(None)` if already complete. + /// + /// This API is stateful - it tracks progress across cancellations and can be + /// safely resumed after being dropped mid-execution. + #[allow(dead_code)] + pub async fn write_current_finish_task(&mut self, stream: &mut S) -> Result> + where + S: AsyncWrite + Unpin + Send, + { + // Use poll_fn to wrap our poll-based implementation + std::future::poll_fn(|cx| self.poll_write_current_finish_task(cx, Pin::new(stream))).await + } + + /// Poll-based implementation for finish tasks. + /// This is the core implementation that maintains state across cancellations. + fn poll_write_current_finish_task( + &mut self, + cx: &mut Context<'_>, + stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // If no finish queued, return None + if matches!( + self.send_finish_state, + FinishWriteState::NotStarted | FinishWriteState::Done + ) { + return Poll::Ready(Ok(None)); + } + + // Route to body-mode-specific implementation + match self.body_mode { + BM::Complete(_) => Poll::Ready(Ok(None)), + BM::ContentLength(_, _) => self.poll_finish_content_length_task(cx, stream), + BM::ChunkedEncoding(_) => self.poll_finish_chunked_task(cx, stream), + BM::UntilClose(_) => self.poll_finish_until_close_task(cx, stream), + BM::ToSelect => Poll::Ready(Ok(None)), + } + } + + /// Finish content-length body - just validates and updates state. + /// No I/O needed since body write tasks already flushed after the last write. + fn poll_finish_content_length_task( + &mut self, + _cx: &mut Context<'_>, + _stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + let written = match self.body_mode { + BM::ContentLength(total, w) => { + if w < total { + self.send_finish_state = FinishWriteState::Done; + return Poll::Ready(Error::e_explain( + PREMATURE_BODY_END, + format!("Content-length: {total} bytes written: {w}"), + )); + } + w + } + _ => panic!("wrong body mode: {:?}", self.body_mode), + }; + + // All bytes written - just update state to Complete + self.body_mode = BM::Complete(written); + self.send_finish_state = FinishWriteState::Done; + Poll::Ready(Ok(Some(written))) + } + + /// Poll-based helper to finish chunked encoding body + fn poll_finish_chunked_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + let written = match self.body_mode { + BM::ChunkedEncoding(w) => w, + _ => panic!("wrong body mode: {:?}", self.body_mode), + }; + + loop { + match &mut self.send_finish_state { + FinishWriteState::Idle => { + // Start writing last chunk marker "0\r\n\r\n" + let buf = WriteBuf::Simple(Bytes::from_static(&LAST_CHUNK[..])); + self.send_finish_state = FinishWriteState::WritingLastChunk(buf); + } + FinishWriteState::WritingLastChunk(buf) => { + // Poll write_vec_all - write until all bytes are written + ready!(poll_write_vec_all_buf(cx, stream.as_mut(), buf)) + .map_err(|e| Error::because(WriteError, "while writing last chunk", e))?; + + // All bytes written, move to flushing state + self.send_finish_state = FinishWriteState::Flushing; + } + FinishWriteState::Flushing => { + // Poll flush + ready!(stream.as_mut().poll_flush(cx)) + .map_err(|e| Error::because(WriteError, "flushing after last chunk", e))?; + + // Flush complete! Update body_mode and mark done + self.body_mode = BM::Complete(written); + self.send_finish_state = FinishWriteState::Done; + return Poll::Ready(Ok(Some(written))); + } + FinishWriteState::Done => { + unreachable!( + "Done state should have been handled in poll_write_current_finish_task" + ) + } + FinishWriteState::NotStarted => { + unreachable!("NotStarted state should have been handled in poll_write_current_finish_task") + } + } + } + } + + /// Finish until-close body - just updates state. + /// No I/O needed since body write tasks already flushed after each write. + fn poll_finish_until_close_task( + &mut self, + _cx: &mut Context<'_>, + _stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + let written = match self.body_mode { + BM::UntilClose(w) => w, + _ => panic!("wrong body mode: {:?}", self.body_mode), + }; + + // Just update state to Complete + self.body_mode = BM::Complete(written); + self.send_finish_state = FinishWriteState::Done; + Poll::Ready(Ok(Some(written))) + } + + // ======================================================================== + // Internal helpers + // ======================================================================== + + /// Internal helper to poll a body task that writes in content-length mode + /// and flushes at end. + fn poll_write_content_length_body_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Move to Writing state if we're Idle + if matches!(self.send_body_state.write_state, WriteState::Idle) { + if let Some(mut bytes) = self.send_body_state.pending_bytes.take() { + let (total, written) = self.body_mode.expect_content_length(); + + // Check if we've already written everything + if written >= total { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + + let original_size = bytes.len(); + let remaining = total - written; + + // Truncate bytes if they exceed content-length + if original_size > remaining { + warn!( + "Trying to write {} bytes over content-length: {}, truncating to {}", + original_size, total, remaining + ); + bytes.truncate(remaining); + } + + let bytes_to_write = bytes.len(); + self.send_body_state.write_state = + WriteState::Writing(bytes_to_write, WriteBuf::Simple(bytes)); + } else { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + } + + // Handle Writing state - do the write, transition to Flushing or Done + if let WriteState::Writing(size, ref mut buf) = &mut self.send_body_state.write_state { + let bytes_written = *size; + + // Attempt write + match ready!(poll_write_all_buf(cx, stream.as_mut(), buf)) { + Ok(()) => { + // Write completed - update body_mode to track bytes written + let (total, written) = self.body_mode.expect_content_length(); + self.body_mode = BM::ContentLength(total, written + bytes_written); + + if written + bytes_written >= total { + // All content-length bytes written, flush needed + self.send_body_state.write_state = WriteState::Flushing(bytes_written); + } else { + // More bytes to come, no flush needed + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + } + Err(e) => { + return Poll::Ready(Error::e_because(WriteError, "while writing body", e)) + } + } + } + + // Handle Flushing state - do the flush, transition to Done + if let WriteState::Flushing(size) = self.send_body_state.write_state { + let bytes_written = size; + + // Attempt flush + match ready!(stream.poll_flush(cx)) { + Ok(()) => { + // Flush completed - transition to Done + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + Err(e) => return Poll::Ready(Error::e_because(WriteError, "flushing body", e)), + } + } + + // Return based on final state + match self.send_body_state.write_state { + WriteState::Done(size) => { + self.send_body_state.timeout_fut = None; + Poll::Ready(Ok(Some(size))) + } + WriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )), + WriteState::Writing(..) | WriteState::Flushing(..) => { + unreachable!("Writing/Flushing states should have been handled above or returned Pending via ready!") + } + WriteState::Idle => { + unreachable!("Idle state should have been handled in setup") + } + } + } + + /// Poll-based implementation for chunked encoding mode + fn poll_write_chunked_body_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Move to Writing state if we're Idle + if matches!(self.send_body_state.write_state, WriteState::Idle) { + if let Some(bytes) = self.send_body_state.pending_bytes.take() { + let application_bytes_size = bytes.len(); + + // Format the chunk: size\r\ndata\r\n + let chunk_size_header = format!("{:X}\r\n", application_bytes_size); + let output_buf = Bytes::from(chunk_size_header) + .chain(bytes) + .chain(&b"\r\n"[..]); + + // Store the chained buffer directly to avoid copying + self.send_body_state.write_state = + WriteState::Writing(application_bytes_size, WriteBuf::Chained(output_buf)); + } else { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + } + + // Handle Writing state - do the write using vectored I/O, transition to Flushing + if let WriteState::Writing(size, ref mut buf) = &mut self.send_body_state.write_state { + let bytes_written = *size; + + // Attempt vectored write for chained buffer (chunk size + data + CRLF) + match ready!(poll_write_vec_all_buf(cx, stream.as_mut(), buf)) { + Ok(()) => { + // Write completed - update body_mode with application bytes (not wire bytes) + let written = self.body_mode.expect_chunked(); + self.body_mode = BM::ChunkedEncoding(written + bytes_written); + + // Chunked encoding always flushes + self.send_body_state.write_state = WriteState::Flushing(bytes_written); + } + Err(e) => { + return Poll::Ready(Error::e_because(WriteError, "while writing body", e)) + } + } + } + + // Handle Flushing state - do the flush, transition to Done + if let WriteState::Flushing(size) = self.send_body_state.write_state { + let bytes_written = size; + + // Attempt flush + match ready!(stream.poll_flush(cx)) { + Ok(()) => { + // Flush completed - transition to Done + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + Err(e) => return Poll::Ready(Error::e_because(WriteError, "flushing body", e)), + } + } + + // Return based on final state + match self.send_body_state.write_state { + WriteState::Done(size) => { + self.send_body_state.timeout_fut = None; + Poll::Ready(Ok(Some(size))) + } + WriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )), + WriteState::Writing(..) | WriteState::Flushing(..) => { + unreachable!("Writing/Flushing states should have been handled above or returned Pending via ready!") + } + WriteState::Idle => { + unreachable!("Idle state should have been handled in setup") + } + } + } + + /// Poll-based implementation for UntilClose (close-delimited) body mode + fn poll_write_until_close_body_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Move to Writing state if we're Idle + if matches!(self.send_body_state.write_state, WriteState::Idle) { + if let Some(bytes) = self.send_body_state.pending_bytes.take() { + let original_size = bytes.len(); + self.send_body_state.write_state = + WriteState::Writing(original_size, WriteBuf::Simple(bytes)); + } else { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + } + + // Handle Writing state - do the write, transition to Flushing + if let WriteState::Writing(size, ref mut buf) = &mut self.send_body_state.write_state { + let bytes_written = *size; + + // Attempt write + match ready!(poll_write_all_buf(cx, stream.as_mut(), buf)) { + Ok(()) => { + // Write completed - update body_mode to track bytes written + let written = self.body_mode.expect_until_close(); + self.body_mode = BM::UntilClose(written + bytes_written); + + // Close-delimited mode always flushes + self.send_body_state.write_state = WriteState::Flushing(bytes_written); + } + Err(e) => { + return Poll::Ready(Error::e_because(WriteError, "while writing body", e)) + } + } + } + + // Handle Flushing state - do the flush, transition to Done + if let WriteState::Flushing(size) = self.send_body_state.write_state { + let bytes_written = size; + + // Attempt flush + match ready!(stream.poll_flush(cx)) { + Ok(()) => { + // Flush completed - transition to Done + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + Err(e) => return Poll::Ready(Error::e_because(WriteError, "flushing body", e)), + } + } + + // Return based on final state + match self.send_body_state.write_state { + WriteState::Done(size) => { + self.send_body_state.timeout_fut = None; + Poll::Ready(Ok(Some(size))) + } + WriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )), + WriteState::Writing(..) | WriteState::Flushing(..) => { + unreachable!("Writing/Flushing states should have been handled above or returned Pending via ready!") + } + WriteState::Idle => { + unreachable!("Idle state should have been handled in setup") + } + } + } } #[cfg(test)] @@ -1717,23 +2442,31 @@ mod tests { let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(0, 0)); assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); - let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(3, 1)); // input1 concat input2 - assert_eq!(&input2[1..2], body_reader.get_body(&res)); - assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 11, 0)); - let res = body_reader.read_body(&mut mock_io).await.unwrap(); - assert_eq!(res, None); - assert_eq!(body_reader.body_state, ParseState::Complete(1)); - assert_eq!(body_reader.get_body_overread(), None); + let _res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); } #[tokio::test] - async fn read_with_body_partial_head_terminal_crlf() { + async fn read_with_body_partial_head_chunk_incomplete() { init_log(); let input1 = b"1\r"; - let input2 = b"\na\r\n0\r\n\r"; - let input3 = b"\n"; - let mut mock_io = Builder::new() + let mut mock_io = Builder::new().read(&input1[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); + let res = body_reader.read_body(&mut mock_io).await; + assert!(res.is_err()); + assert_eq!(body_reader.body_state, ParseState::Done(0)); + } + + #[tokio::test] + async fn read_with_body_partial_head_terminal_crlf() { + init_log(); + let input1 = b"1\r"; + let input2 = b"\na\r\n0\r\n\r"; + let input3 = b"\n"; + let mut mock_io = Builder::new() .read(&input1[..]) .read(&input2[..]) .read(&input3[..]) @@ -1925,21 +2658,6 @@ mod tests { assert_eq!(body_reader.get_body_overread(), Some(&b"abc"[..])); } - #[tokio::test] - async fn read_with_body_partial_head_chunk_incomplete() { - init_log(); - let input1 = b"1\r"; - let mut mock_io = Builder::new().read(&input1[..]).build(); - let mut body_reader = BodyReader::new(false); - body_reader.init_chunked(b""); - let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(0, 0)); - assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); - let res = body_reader.read_body(&mut mock_io).await; - assert!(res.is_err()); - assert_eq!(body_reader.body_state, ParseState::Done(0)); - } - #[tokio::test] async fn read_with_body_trailers() { init_log(); @@ -2319,7 +3037,7 @@ mod tests { } #[tokio::test] - async fn write_body_http10() { + async fn write_body_until_close() { init_log(); let data = b"a"; let mut mock_io = Builder::new().write(&data[..]).write(&data[..]).build(); @@ -2345,3 +3063,768 @@ mod tests { assert_eq!(body_writer.body_mode, BodyMode::Complete(2)); } } + +#[cfg(test)] +mod test_body_task_api { + use super::*; + use tokio_test::io::Builder; + + // Cancel-safety tests use tokio::select! to race a short sleep against a mock + // I/O wait, simulating cancellation. We use #[tokio::test(start_paused = true)] + // on these tests so that tokio auto-advances time deterministically rather than + // relying on wall-clock timing. + + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[tokio::test] + async fn test_has_pending_body_task() { + init_log(); + let data = b"test data"; + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Initially should have no pending task + assert!(!body_writer.has_pending_body_task()); + + // After queuing bytes, should have pending task + body_writer.send_body_task(Bytes::from_static(data), None); + assert!(body_writer.has_pending_body_task()); + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_content_length_write() { + init_log(); + let data = b"Hello, World!"; + + // Create a mock stream that will block to allow cancellation + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue the bytes to write + body_writer.send_body_task(Bytes::from_static(data), None); + + // Use tokio::select! loop - keep looping until write completes + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + // Break if no pending writes + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + // Timeout fires first, cancelling the write + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + // Write completed + assert!(result.is_ok(), "Write should succeed"); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!( + cancel_count > 0, + "At least one cancellation should have occurred" + ); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + assert_eq!( + body_writer.body_mode, + BodyMode::ContentLength(data.len(), data.len()) + ); + + // Now test finish() in a select loop as well + let mut mock_io_finish = Builder::new().build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => { + // Allow cancellation attempts + } + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + + assert_eq!(body_writer.body_mode, BodyMode::Complete(data.len())); + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_chunked_write() { + init_log(); + let data = b"abcdefghij"; + let expected_output = b"A\r\nabcdefghij\r\n"; + + // Mock stream that blocks to allow cancellation + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(expected_output) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_chunked(); + + // Queue bytes + body_writer.send_body_task(Bytes::from_static(data), None); + + // Use select loop - keep looping until write completes + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count > 0, "Should have cancelled at least once"); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + assert_eq!(body_writer.body_mode, BodyMode::ChunkedEncoding(data.len())); + + // Test finish() with select loop - must write terminating chunk + let mut mock_io_finish = Builder::new() + .wait(std::time::Duration::from_millis(50)) + .write(&LAST_CHUNK[..]) // Expect 0\r\n\r\n + .build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + + assert_eq!(body_writer.body_mode, BodyMode::Complete(data.len())); + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_until_close_write() { + init_log(); + let data = b"test data"; + + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_close_delimited(); + + body_writer.send_body_task(Bytes::from_static(data), None); + + // Use select loop - keep looping until write completes + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count > 0); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + + // Test finish() with select loop + let mut mock_io_finish = Builder::new().build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_multiple_cancellations() { + init_log(); + let data = b"Long test data that requires multiple writes"; + + // Create a mock that blocks multiple times + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(50)) + .write(&data[..15]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[15..30]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[30..]) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + body_writer.send_body_task(Bytes::from_static(data), None); + + // Loop until write completes, allowing cancellations + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count >= 2, "Should have multiple cancellations"); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + + // Test finish with select loop + let mut mock_io_finish = Builder::new().build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_partial_writes() { + init_log(); + let data = b"12345678901234567890"; // 20 bytes + + // Simulate partial writes with blocking + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(50)) + .write(&data[..7]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[7..14]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[14..]) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + body_writer.send_body_task(Bytes::from_static(data), None); + + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count > 0); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + + // Test finish in select loop + let mut mock_io_finish = Builder::new() + .wait(std::time::Duration::from_millis(30)) + .build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok(), "Finish should succeed after cancel-safe writes"); + break; + } + } + } + } + + #[tokio::test] + async fn test_task_write_timeout() { + init_log(); + let data = b"test data"; + + // Create a mock that blocks forever + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue the task with a timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(50)), + ); + + // The write should timeout + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_err(), "Write should timeout"); + + // Check that it's a timeout error + if let Err(e) = result { + assert_eq!(e.etype(), &WriteTimedout); + } + } + + // Even if the user's select! cancels the write, the internal timeout + // should continue counting across cancellations. + #[tokio::test] + async fn test_task_timeout_persists_across_cancellations() { + init_log(); + let data = b"test data"; + + // Create a mock that blocks for a while + // Since timeout is 100ms and this waits 200ms, the write should never happen + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(200)) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue the task with a 100ms timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(100)), + ); + + let mut attempts = 0; + let mut timedout = false; + + // Try to write in a loop, but cancel early each time + // The timeout should still fire even though we're cancelling + loop { + if !body_writer.has_pending_body_task() { + break; + } + + attempts += 1; + + tokio::select! { + // Cancel after just 10ms each time + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + // Cancelled by our select, continue looping + continue; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + match result { + Ok(_) => { + // Write succeeded before timeout + break; + } + Err(e) if e.etype() == &WriteTimedout => { + // Timeout fired! + timedout = true; + break; + } + Err(e) => { + panic!("Unexpected error: {:?}", e); + } + } + } + } + } + + assert!(timedout, "Timeout should have fired despite cancellations"); + assert!( + attempts >= 5, + "Should have had multiple attempts before timeout" + ); + } + + #[tokio::test] + async fn test_task_write_succeeds_within_timeout() { + init_log(); + let data = b"Hello, World!"; + + // Create a mock that completes quickly + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(20)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue with a generous timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(500)), + ); + + // Write should succeed + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok(), "Write should succeed: {:?}", result); + assert_eq!(result.unwrap(), Some(data.len())); + } + + #[tokio::test] + async fn test_task_write_no_timeout() { + init_log(); + let data = b"test data"; + + // Create a mock that takes a bit of time but eventually succeeds + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue without timeout + body_writer.send_body_task(Bytes::from_static(data), None); + + // Write should eventually succeed + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok(), "Write should succeed without timeout"); + assert_eq!(result.unwrap(), Some(data.len())); + } + + #[tokio::test] + async fn test_task_chunked_write_timeout() { + init_log(); + let data = b"chunked data"; + + // Create a mock that blocks + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_chunked(); + + // Queue with short timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(50)), + ); + + // Should timeout + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!(e.etype(), &WriteTimedout); + } + } + + #[tokio::test] + async fn test_task_timeout_reset_on_new_task() { + init_log(); + let data1 = b"first"; + let data2 = b"second"; + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data1.len() + data2.len()); + + // Queue first task with short timeout + body_writer.send_body_task( + Bytes::from_static(data1), + Some(std::time::Duration::from_millis(50)), + ); + + // Wait a bit but don't let it timeout yet + tokio::time::sleep(std::time::Duration::from_millis(30)).await; + + // Queue a new task with a longer timeout + // This should reset/replace the timeout + body_writer.send_body_task( + Bytes::from_static(data2), + Some(std::time::Duration::from_millis(500)), + ); + + // Create a mock that takes some time + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data2) + .build(); + + // The second write should succeed with its own timeout + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!( + result.is_ok(), + "Second task should succeed with new timeout" + ); + } + + #[tokio::test] + async fn test_task_timeout_with_partial_writes() { + init_log(); + let data1 = b"first"; + let data2 = b"second"; + let data3 = b"third"; + + // Mock that writes data1 quickly, data2 with delay, data3 blocks forever + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(10)) + .write(data1) + .wait(std::time::Duration::from_millis(40)) + .write(data2) + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data1.len() + data2.len() + data3.len()); + + let mut total_written = 0; + + // First write - should succeed within timeout + body_writer.send_body_task( + Bytes::from_static(data1), + Some(std::time::Duration::from_millis(100)), + ); + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok()); + total_written += result.unwrap().unwrap(); + + // Second write - should succeed within timeout + body_writer.send_body_task( + Bytes::from_static(data2), + Some(std::time::Duration::from_millis(100)), + ); + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok()); + total_written += result.unwrap().unwrap(); + + // Third write - should timeout + body_writer.send_body_task( + Bytes::from_static(data3), + Some(std::time::Duration::from_millis(50)), + ); + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().etype(), &WriteTimedout); + + // We should have written data1 and data2 but not data3 + assert_eq!(total_written, data1.len() + data2.len()); + assert!( + total_written < data1.len() + data2.len() + data3.len(), + "Should not have written all data" + ); + } + + // Cancel-safe finish task for chunked encoding: send_finish_task() queues + // the terminating chunk, write_current_finish_task() writes it and can be + // cancelled and resumed in a select! loop. + #[tokio::test(start_paused = true)] + async fn cancel_safe_finish_task_chunked() { + init_log(); + + let data = Bytes::from("hello"); + let expected_chunk = b"5\r\nhello\r\n"; + + let mut mock_io = Builder::new().write(expected_chunk).build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_chunked(); + + // Write body data via task API + body_writer.send_body_task(data, None); + body_writer + .write_current_body_task(&mut mock_io) + .await + .unwrap(); + + // Queue the finish task + body_writer.send_finish_task(); + assert!(body_writer.has_pending_finish_task()); + + // Write the finish in a select! loop with cancellations + let mut mock_io_finish = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(b"0\r\n\r\n") + .build(); + + let mut cancel_count = 0; + + loop { + if !body_writer.has_pending_finish_task() { + break; + } + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_finish_task(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + + assert!(cancel_count > 0, "Should have cancelled at least once"); + assert!(matches!(body_writer.body_mode, BodyMode::Complete(_))); + } + + // Finish task for content-length is a no-op (no terminating chunk needed), + // but it should still transition body_mode to Complete. + #[tokio::test] + async fn finish_task_content_length() { + init_log(); + + let data = b"hello"; + let mut mock_io = Builder::new().write(data).build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + body_writer.send_body_task(Bytes::from_static(data), None); + body_writer + .write_current_body_task(&mut mock_io) + .await + .unwrap(); + + body_writer.send_finish_task(); + let mut mock_io_finish = Builder::new().build(); + let result = body_writer + .write_current_finish_task(&mut mock_io_finish) + .await; + assert!(result.is_ok()); + assert!(matches!(body_writer.body_mode, BodyMode::Complete(_))); + } + + // Verifies that body_mode byte tracking is correct when writing + // content-length body in multiple chunks. Each intermediate chunk + // does not trigger a flush; the body_mode must still accumulate + // bytes correctly so that finish_task succeeds. + #[tokio::test] + async fn content_length_body_mode_tracks_across_chunks() { + init_log(); + + let chunk1 = b"Hello"; + let chunk2 = b", World!"; + let total_len = chunk1.len() + chunk2.len(); // 13 + + // Mock expects both writes; the final write triggers a flush internally + let mut mock_io = Builder::new().write(chunk1).write(chunk2).build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(total_len); + + // Write first chunk (intermediate, no flush expected) + body_writer.send_body_task(Bytes::from_static(chunk1), None); + let result = body_writer + .write_current_body_task(&mut mock_io) + .await + .unwrap(); + assert_eq!(result, Some(chunk1.len())); + assert!( + !body_writer.finished(), + "Should not be finished after first chunk" + ); + + // Verify body_mode tracks the bytes from the first chunk + assert!( + matches!(body_writer.body_mode, BodyMode::ContentLength(total, written) + if total == total_len && written == chunk1.len()), + "body_mode should reflect bytes written so far, got: {:?}", + body_writer.body_mode + ); + + // Write second chunk (final, completes content-length) + body_writer.send_body_task(Bytes::from_static(chunk2), None); + let result = body_writer + .write_current_body_task(&mut mock_io) + .await + .unwrap(); + assert_eq!(result, Some(chunk2.len())); + assert!( + body_writer.finished(), + "Should be finished after all bytes written" + ); + + // Finish should succeed since all content-length bytes were written + body_writer.send_finish_task(); + let mut mock_io_finish = Builder::new().build(); + let result = body_writer + .write_current_finish_task(&mut mock_io_finish) + .await; + assert!( + result.is_ok(), + "finish_task should succeed when all content-length bytes written" + ); + assert!(matches!(body_writer.body_mode, BodyMode::Complete(_))); + } +} diff --git a/pingora-core/src/protocols/http/v1/header.rs b/pingora-core/src/protocols/http/v1/header.rs new file mode 100644 index 00000000..39eb9f1b --- /dev/null +++ b/pingora-core/src/protocols/http/v1/header.rs @@ -0,0 +1,449 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Cancel-safe header writing for HTTP/1.x + +use bytes::Bytes; +use pingora_error::{Error, ErrorType::*, Result}; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; +use tokio::io::AsyncWrite; + +use crate::protocols::l4::stream::async_write_vec::poll_write_all_buf; + +#[allow(dead_code)] +enum HeaderWriteState { + /// No write in progress + Idle, + /// Writing header bytes (original size, buffer) + Writing(usize, Bytes), + /// Flushing after write (original size to return) + Flushing(usize), + /// Write complete + Done, + /// Write timed out - cannot be reused + TimedOut, +} + +// Custom Debug implementation +impl std::fmt::Debug for HeaderWriteState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + HeaderWriteState::Idle => write!(f, "Idle"), + HeaderWriteState::Writing(size, _) => write!(f, "Writing(size: {})", size), + HeaderWriteState::Flushing(size) => write!(f, "Flushing(size: {})", size), + HeaderWriteState::Done => write!(f, "Done"), + HeaderWriteState::TimedOut => write!(f, "TimedOut"), + } + } +} + +/// Internal state for the cancel-safe header write state machine. +/// +/// Tracks the pending header bytes, write progress (idle → writing → flushing → done), +/// and an optional timeout that is lazily created on the first `Pending` poll. +#[allow(dead_code)] +struct SendHeaderState { + /// Serialized header bytes ready to be written + pending_header: Option, + /// Whether to flush after writing + should_flush: bool, + /// Current write state + write_state: HeaderWriteState, + /// Timeout duration for this write task + timeout_duration: Option, + /// Timeout future (only created if write returns Pending) + timeout_fut: Option + Send + Sync>>>, +} + +// Custom Debug implementation since timeout_fut doesn't implement Debug +impl std::fmt::Debug for SendHeaderState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SendHeaderState") + .field("pending_header", &self.pending_header) + .field("should_flush", &self.should_flush) + .field("write_state", &self.write_state) + .field("timeout_duration", &self.timeout_duration) + .field( + "timeout_fut", + &self.timeout_fut.as_ref().map(|_| "Some(Future)"), + ) + .finish() + } +} + +impl SendHeaderState { + #[allow(dead_code)] + fn new() -> Self { + SendHeaderState { + pending_header: None, + should_flush: false, + write_state: HeaderWriteState::Idle, + timeout_duration: None, + timeout_fut: None, + } + } +} + +/// Cancel-safe header writer for HTTP/1.x response headers. +/// +/// This writer allows response headers to be written to a downstream connection +/// inside a `tokio::select!` loop without losing progress. If the write is +/// cancelled (e.g. because another branch of the select fires first), the +/// partially-written state is preserved and will be resumed on the next call to +/// [`write_current_header_task`](Self::write_current_header_task). +/// +/// ## Usage +/// +/// 1. Call [`send_header_task`](Self::send_header_task) with pre-serialized +/// header bytes, a flush flag, and an optional timeout. +/// 2. Await [`write_current_header_task`](Self::write_current_header_task) +/// (possibly inside `tokio::select!`). The method returns `Ok(bytes_written)` +/// on success. +/// +/// A timeout, if set, is enforced *across* cancellations — the clock keeps +/// ticking even when the future is dropped and re-polled. +#[allow(dead_code)] +pub struct HeaderWriter { + // Boxed to reduce inline size. Only used by the cancel-safe proxy task API. + send_header_state: Box, +} + +impl Default for HeaderWriter { + fn default() -> Self { + Self::new() + } +} + +impl HeaderWriter { + #[allow(dead_code)] + pub fn new() -> Self { + HeaderWriter { + send_header_state: Box::new(SendHeaderState::new()), + } + } + + #[cfg(test)] + pub fn has_pending_header_task(&self) -> bool { + self.send_header_state.pending_header.is_some() + || !matches!( + self.send_header_state.write_state, + HeaderWriteState::Idle | HeaderWriteState::Done | HeaderWriteState::TimedOut + ) + } + + /// Queue serialized header bytes as a write task with an optional timeout. + /// This is a non-async function that just saves the bytes. + /// Call [`write_current_header_task`](Self::write_current_header_task) to actually perform the write. + #[allow(dead_code)] + pub fn send_header_task( + &mut self, + header_bytes: Bytes, + should_flush: bool, + timeout: Option, + ) { + assert!( + matches!( + self.send_header_state.write_state, + HeaderWriteState::Idle | HeaderWriteState::Done + ), + "send_header_task called while previous task is still in progress: {:?}", + self.send_header_state.write_state + ); + self.send_header_state.pending_header = Some(header_bytes); + self.send_header_state.should_flush = should_flush; + self.send_header_state.write_state = HeaderWriteState::Idle; + self.send_header_state.timeout_duration = timeout; + self.send_header_state.timeout_fut = None; + } + + /// Async function that writes the current queued header task to the stream. + /// This function is cancel-safe and can be called in a `tokio::select!` loop. + /// Returns `Ok(bytes_written)` when complete, `Ok(0)` if no bytes to write. + #[allow(dead_code)] + pub async fn write_current_header_task(&mut self, stream: &mut S) -> Result + where + S: AsyncWrite + Unpin + Send, + { + std::future::poll_fn(|cx| self.poll_write_current_header_task(cx, Pin::new(stream))).await + } + + /// Poll-based implementation for writing the current header task. + fn poll_write_current_header_task( + &mut self, + cx: &mut Context<'_>, + stream: Pin<&mut S>, + ) -> Poll> + where + S: AsyncWrite + Unpin + Send, + { + // Check if already timed out - don't allow reuse + if matches!( + self.send_header_state.write_state, + HeaderWriteState::TimedOut + ) { + return Poll::Ready(Error::e_explain( + WriteTimedout, + "header write task previously timed out", + )); + } + + // First, try the write operation + match self.poll_do_write_header_and_flush(cx, stream) { + Poll::Ready(Ok(size)) => { + // Write completed! Clear timeout and return + if matches!(self.send_header_state.write_state, HeaderWriteState::Done) { + self.send_header_state.timeout_fut = None; + } + return Poll::Ready(Ok(size)); + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + // Write is pending - now check timeout + } + } + + // Lazy timeout optimization: Polls write first, creates timeout only if needed. + // This follows the pattern from `pingora_timeout::Timeout` to avoid allocating + // timeout futures when writes complete immediately (the common case). + if let Some(duration) = self.send_header_state.timeout_duration { + let timeout = self.send_header_state.timeout_fut.get_or_insert_with(|| { + Box::pin(pingora_timeout::sleep(duration)) + as std::pin::Pin + Send + Sync>> + }); + + if timeout.as_mut().poll(cx).is_ready() { + // Timeout fired! + self.send_header_state.write_state = HeaderWriteState::TimedOut; + self.send_header_state.timeout_fut = None; + return Poll::Ready(Error::e_explain( + WriteTimedout, + "writing header task timed out", + )); + } + } + + // Both write and timeout are pending + Poll::Pending + } + + /// Poll-based helper to write header bytes and optionally flush. + /// Handles state transitions explicitly. + fn poll_do_write_header_and_flush( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll> + where + S: AsyncWrite + Unpin + Send, + { + // Handle Idle state - take pending header and transition to Writing + if matches!(self.send_header_state.write_state, HeaderWriteState::Idle) { + if let Some(header_bytes) = self.send_header_state.pending_header.take() { + let size = header_bytes.len(); + self.send_header_state.write_state = HeaderWriteState::Writing(size, header_bytes); + } else { + // No pending header + self.send_header_state.write_state = HeaderWriteState::Done; + return Poll::Ready(Ok(0)); + } + } + + // Write if in Writing state + if let HeaderWriteState::Writing(original_size, ref mut buf) = + self.send_header_state.write_state + { + let size = original_size; + ready!(poll_write_all_buf(cx, stream.as_mut(), buf)) + .map_err(|e| Error::because(WriteError, "writing response header", e))?; + + // Write complete - transition to next state + if self.send_header_state.should_flush { + self.send_header_state.write_state = HeaderWriteState::Flushing(size); + } else { + self.send_header_state.write_state = HeaderWriteState::Done; + return Poll::Ready(Ok(size)); + } + } + + // Handle the state after writing (or if we started in a non-Writing state) + match self.send_header_state.write_state { + HeaderWriteState::Flushing(size) => { + ready!(stream.as_mut().poll_flush(cx)) + .map_err(|e| Error::because(WriteError, "flushing response header", e))?; + // Flush complete - transition to Done + self.send_header_state.write_state = HeaderWriteState::Done; + Poll::Ready(Ok(size)) + } + HeaderWriteState::Done => Poll::Ready(Ok(0)), + HeaderWriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "header write task previously timed out", + )), + HeaderWriteState::Idle => { + unreachable!("Idle state should have been handled above") + } + HeaderWriteState::Writing(..) => { + unreachable!("Writing state should have been handled above") + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio_test::io::Builder; + + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[tokio::test] + async fn test_simple_header_write() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + + let mut mock_io = Builder::new().write(header_data).build(); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task(Bytes::from_static(header_data), false, None); + + let result = header_writer.write_current_header_task(&mut mock_io).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), header_data.len()); + } + + #[tokio::test] + async fn test_header_write_with_flush() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\n\r\n"; + + let mut mock_io = Builder::new().write(header_data).build(); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task(Bytes::from_static(header_data), true, None); + + let result = header_writer.write_current_header_task(&mut mock_io).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), header_data.len()); + } + + // Uses start_paused for deterministic timer-based cancellation in select! + #[tokio::test(start_paused = true)] + async fn test_cancel_safe_header_write() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\nServer: pingora\r\n\r\n"; + + // Mock that blocks to allow cancellation + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(header_data) + .build(); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task(Bytes::from_static(header_data), false, None); + + let mut cancel_count = 0; + + loop { + if !header_writer.has_pending_header_task() { + break; + } + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = header_writer.write_current_header_task(&mut mock_io) => { + assert!(result.is_ok()); + assert_eq!(result.unwrap(), header_data.len()); + break; + } + } + } + + assert!(cancel_count > 0, "Should have cancelled at least once"); + } + + #[tokio::test] + async fn test_header_write_timeout() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\n\r\n"; + + // Mock that blocks forever + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task( + Bytes::from_static(header_data), + false, + Some(std::time::Duration::from_millis(50)), + ); + + let result = header_writer.write_current_header_task(&mut mock_io).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().etype(), &WriteTimedout); + } + + #[tokio::test] + async fn test_header_write_timeout_persists() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\n\r\n"; + + // Mock that blocks for a while + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(200)) + .build(); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task( + Bytes::from_static(header_data), + false, + Some(std::time::Duration::from_millis(100)), + ); + + let mut attempts = 0; + let mut timedout = false; + + loop { + if !header_writer.has_pending_header_task() { + break; + } + + attempts += 1; + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + continue; + } + result = header_writer.write_current_header_task(&mut mock_io) => { + match result { + Ok(_) => break, + Err(e) if e.etype() == &WriteTimedout => { + timedout = true; + break; + } + Err(e) => panic!("Unexpected error: {:?}", e), + } + } + } + } + + assert!(timedout, "Timeout should have fired"); + assert!(attempts >= 5, "Should have had multiple attempts"); + } +} diff --git a/pingora-core/src/protocols/http/v1/mod.rs b/pingora-core/src/protocols/http/v1/mod.rs index 19602491..53acaec9 100644 --- a/pingora-core/src/protocols/http/v1/mod.rs +++ b/pingora-core/src/protocols/http/v1/mod.rs @@ -17,4 +17,5 @@ pub(crate) mod body; pub mod client; pub mod common; +pub(crate) mod header; pub mod server; diff --git a/pingora-core/src/protocols/l4/stream.rs b/pingora-core/src/protocols/l4/stream.rs index 4aa70f70..ddbaceb1 100644 --- a/pingora-core/src/protocols/l4/stream.rs +++ b/pingora-core/src/protocols/l4/stream.rs @@ -814,14 +814,67 @@ pub mod async_write_vec { fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { let me = &mut *self; - while me.buf.has_remaining() { - let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?; - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } + poll_write_vec_all_buf(ctx, Pin::new(&mut *me.writer), me.buf) + } + } + + /// Primitive poll function to write ALL bytes from a buffer using vectored writes. + /// Keeps polling `poll_write_vec` until the entire buffer is written. + /// The buffer is advanced as bytes are written. + /// + /// Returns Poll::Ready(Ok(())) when all bytes are written. + /// Returns WriteZero error if poll_write_vec returns 0. + /// + /// This is essentially a polling form of tokio's + /// [`write_all_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncWriteExt.html#method.write_all_buf). + // TODO: we should be able to switch over to polling the future from tokio AsyncWriteExt directly, + // for now we continue to use the old trait. + pub fn poll_write_vec_all_buf( + ctx: &mut Context<'_>, + mut writer: Pin<&mut W>, + buf: &mut B, + ) -> Poll> + where + W: AsyncWriteVec + ?Sized, + B: Buf, + { + while buf.has_remaining() { + let n = ready!(writer.as_mut().poll_write_vec(ctx, buf))?; + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } - Poll::Ready(Ok(())) } + Poll::Ready(Ok(())) + } + + /// Primitive poll function to write ALL bytes from a buffer using regular writes. + /// Keeps polling `poll_write` until the entire buffer is written. + /// The buffer is advanced as bytes are written. + /// + /// Returns Poll::Ready(Ok(())) when all bytes are written. + /// Returns WriteZero error if poll_write returns 0. + /// + /// This is essentially a polling form of tokio's + /// [`write_all_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncWriteExt.html#method.write_all_buf) + /// though we explicitly use non-vectored writes in this case for strict parity with the + /// original `write_all` method. + pub fn poll_write_all_buf( + ctx: &mut Context<'_>, + mut writer: Pin<&mut W>, + buf: &mut B, + ) -> Poll> + where + W: AsyncWrite + ?Sized, + B: Buf, + { + while buf.has_remaining() { + let n = ready!(writer.as_mut().poll_write(ctx, buf.chunk()))?; + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + buf.advance(n); + } + Poll::Ready(Ok(())) } /* from https://github.com/tokio-rs/tokio/blob/master/tokio-util/src/lib.rs#L177 */ From 5a822047b615f3eb74d8135aa80c49c17dfa3e7f Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Thu, 19 Mar 2026 17:15:53 -0700 Subject: [PATCH 15/52] Add proxy task API for v1 server sessions --- .bleep | 2 +- pingora-core/src/protocols/http/server.rs | 51 ++ pingora-core/src/protocols/http/v1/body.rs | 56 +- pingora-core/src/protocols/http/v1/header.rs | 34 +- pingora-core/src/protocols/http/v1/mod.rs | 105 +++ pingora-core/src/protocols/http/v1/server.rs | 761 ++++++++++++++++--- pingora-proxy/src/lib.rs | 139 ++-- 7 files changed, 972 insertions(+), 176 deletions(-) diff --git a/.bleep b/.bleep index 8ca85d6c..5a60cfc8 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -ade24f55fc0b8c3be1b0a22da73cfe94058f811c \ No newline at end of file +b8823a8f0713f33ec7f83a5c2df8d5491c8a5613 \ No newline at end of file diff --git a/pingora-core/src/protocols/http/server.rs b/pingora-core/src/protocols/http/server.rs index 78852939..65c51723 100644 --- a/pingora-core/src/protocols/http/server.rs +++ b/pingora-core/src/protocols/http/server.rs @@ -811,4 +811,55 @@ impl Session { Self::Custom(_) => None, } } + + /// Check if this session supports the cancel-safe proxy task API. + pub fn supports_proxy_task_api(&self) -> bool { + // only H1 for now + matches!(self, Self::H1(_)) + } + + /// Queue a downstream proxy task for cancel-safe writing. + /// + /// # Panics + /// Panics if called on a session that doesn't support the proxy task API. + /// Check [`supports_proxy_task_api`](Self::supports_proxy_task_api) first, + /// or use `write_response_header()` / `write_response_body()` for other + /// session types. + pub fn send_downstream_proxy_task(&mut self, task: HttpTask) { + match self { + Self::H1(s) => s.send_proxy_task(task), + Self::H2(_) => panic!("H2 proxy task API not yet implemented"), + Self::Subrequest(_) => panic!("Subrequest proxy task API not yet implemented"), + Self::Custom(_) => panic!("Custom proxy task API not yet implemented"), + } + } + + /// Check if there are pending downstream proxy tasks queued for writing. + /// + /// Returns false for sessions that don't support the proxy task API. + pub fn has_pending_downstream_proxy_tasks(&self) -> bool { + match self { + Self::H1(s) => s.has_pending_proxy_tasks(), + Self::H2(_) => false, // TODO: implement for H2 + Self::Subrequest(_) => false, // TODO: implement for subrequests + Self::Custom(_) => false, // TODO: implement for custom + } + } + + /// Write all queued downstream proxy tasks in a cancel-safe manner. + /// Returns `Ok(true)` if this was the end of the response stream. + /// + /// # Panics + /// Panics if called on a session that doesn't support the proxy task API. + /// Check [`supports_proxy_task_api`](Self::supports_proxy_task_api) first, + /// or use `write_response_header()` / `write_response_body()` for other + /// session types. + pub async fn write_downstream_proxy_tasks(&mut self) -> Result { + match self { + Self::H1(s) => s.write_proxy_tasks().await, + Self::H2(_) => panic!("H2 proxy task API not yet implemented"), + Self::Subrequest(_) => panic!("Subrequest proxy task API not yet implemented"), + Self::Custom(_) => panic!("Custom proxy task API not yet implemented"), + } + } } diff --git a/pingora-core/src/protocols/http/v1/body.rs b/pingora-core/src/protocols/http/v1/body.rs index fbed3b11..61872af6 100644 --- a/pingora-core/src/protocols/http/v1/body.rs +++ b/pingora-core/src/protocols/http/v1/body.rs @@ -943,7 +943,6 @@ impl BodyMode { /// Type alias for the chunked encoding buffer chain type ChunkedBuf = bytes::buf::Chain, &'static [u8]>; -#[allow(dead_code)] enum WriteBuf { /// Simple bytes buffer Simple(Bytes), @@ -975,7 +974,6 @@ impl Buf for WriteBuf { } } -#[allow(dead_code)] enum WriteState { /// No write in progress Idle, @@ -1004,7 +1002,6 @@ impl std::fmt::Debug for WriteState { } } -#[allow(dead_code)] enum FinishWriteState { /// No finish task queued NotStarted, @@ -1079,9 +1076,7 @@ impl SendBodyState { pub struct BodyWriter { pub body_mode: BodyMode, // Boxed to reduce inline size. Only used by the cancel-safe proxy task API. - #[allow(dead_code)] send_body_state: Box, - #[allow(dead_code)] send_finish_state: FinishWriteState, } @@ -1313,7 +1308,6 @@ impl BodyWriter { /// /// The timeout, if provided, will be enforced internally across all /// write attempts, even if the write is cancelled and resumed via `tokio::select!`. - #[allow(dead_code)] pub fn send_body_task(&mut self, bytes: Bytes, timeout: Option) { assert!( matches!( @@ -1335,7 +1329,6 @@ impl BodyWriter { /// /// This function can be safely used in a `tokio::select!` loop. /// Returns `Ok(Some(bytes_written))` when complete, `Ok(None)` if no bytes to write. - #[allow(dead_code)] pub async fn write_current_body_task(&mut self, stream: &mut S) -> Result> where S: AsyncWrite + Unpin + Send, @@ -1425,7 +1418,6 @@ impl BodyWriter { /// /// This API is stateful and cancel-safe - use it when you need to finish /// the body in a `tokio::select!` loop or other cancellable context. - #[allow(dead_code)] pub fn send_finish_task(&mut self) { self.send_finish_state = FinishWriteState::Idle; } @@ -1436,7 +1428,6 @@ impl BodyWriter { /// /// This API is stateful - it tracks progress across cancellations and can be /// safely resumed after being dropped mid-execution. - #[allow(dead_code)] pub async fn write_current_finish_task(&mut self, stream: &mut S) -> Result> where S: AsyncWrite + Unpin + Send, @@ -3067,6 +3058,7 @@ mod tests { #[cfg(test)] mod test_body_task_api { use super::*; + use crate::protocols::http::v1::test_util::FlushTrackingMock; use tokio_test::io::Builder; // Cancel-safety tests use tokio::select! to race a short sleep against a mock @@ -3687,6 +3679,7 @@ mod test_body_task_api { // Cancel-safe finish task for chunked encoding: send_finish_task() queues // the terminating chunk, write_current_finish_task() writes it and can be // cancelled and resumed in a select! loop. + // Verifies that the finish flushes the stream exactly once. #[tokio::test(start_paused = true)] async fn cancel_safe_finish_task_chunked() { init_log(); @@ -3694,7 +3687,8 @@ mod test_body_task_api { let data = Bytes::from("hello"); let expected_chunk = b"5\r\nhello\r\n"; - let mut mock_io = Builder::new().write(expected_chunk).build(); + let mock_io = Builder::new().write(expected_chunk).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); let mut body_writer = BodyWriter::new(); body_writer.init_chunked(); @@ -3702,19 +3696,27 @@ mod test_body_task_api { // Write body data via task API body_writer.send_body_task(data, None); body_writer - .write_current_body_task(&mut mock_io) + .write_current_body_task(&mut flush_mock) .await .unwrap(); + // Chunked body writes always flush after each chunk + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 1, + "Chunked body data write should flush once" + ); + // Queue the finish task body_writer.send_finish_task(); assert!(body_writer.has_pending_finish_task()); // Write the finish in a select! loop with cancellations - let mut mock_io_finish = Builder::new() + let mock_io_finish = Builder::new() .wait(std::time::Duration::from_millis(100)) .write(b"0\r\n\r\n") .build(); + let (mut flush_mock_finish, flush_count_finish) = FlushTrackingMock::new(mock_io_finish); let mut cancel_count = 0; @@ -3727,7 +3729,7 @@ mod test_body_task_api { _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { cancel_count += 1; } - result = body_writer.write_current_finish_task(&mut mock_io_finish) => { + result = body_writer.write_current_finish_task(&mut flush_mock_finish) => { assert!(result.is_ok()); break; } @@ -3736,33 +3738,53 @@ mod test_body_task_api { assert!(cancel_count > 0, "Should have cancelled at least once"); assert!(matches!(body_writer.body_mode, BodyMode::Complete(_))); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count_finish), + 1, + "Chunked finish should flush exactly once" + ); } // Finish task for content-length is a no-op (no terminating chunk needed), // but it should still transition body_mode to Complete. + // Verifies that no flush occurs (content-length finish has no I/O). #[tokio::test] async fn finish_task_content_length() { init_log(); let data = b"hello"; - let mut mock_io = Builder::new().write(data).build(); + let mock_io = Builder::new().write(data).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); let mut body_writer = BodyWriter::new(); body_writer.init_content_length(data.len()); body_writer.send_body_task(Bytes::from_static(data), None); body_writer - .write_current_body_task(&mut mock_io) + .write_current_body_task(&mut flush_mock) .await .unwrap(); + // Content-length body write flushes when all bytes are written + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 1, + "Content-length body write should flush once (all bytes written)" + ); + body_writer.send_finish_task(); - let mut mock_io_finish = Builder::new().build(); + let mock_io_finish = Builder::new().build(); + let (mut flush_mock_finish, flush_count_finish) = FlushTrackingMock::new(mock_io_finish); let result = body_writer - .write_current_finish_task(&mut mock_io_finish) + .write_current_finish_task(&mut flush_mock_finish) .await; assert!(result.is_ok()); assert!(matches!(body_writer.body_mode, BodyMode::Complete(_))); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count_finish), + 0, + "Content-length finish should not flush (no I/O needed)" + ); } // Verifies that body_mode byte tracking is correct when writing diff --git a/pingora-core/src/protocols/http/v1/header.rs b/pingora-core/src/protocols/http/v1/header.rs index 39eb9f1b..b6abdb71 100644 --- a/pingora-core/src/protocols/http/v1/header.rs +++ b/pingora-core/src/protocols/http/v1/header.rs @@ -22,7 +22,6 @@ use tokio::io::AsyncWrite; use crate::protocols::l4::stream::async_write_vec::poll_write_all_buf; -#[allow(dead_code)] enum HeaderWriteState { /// No write in progress Idle, @@ -53,7 +52,6 @@ impl std::fmt::Debug for HeaderWriteState { /// /// Tracks the pending header bytes, write progress (idle → writing → flushing → done), /// and an optional timeout that is lazily created on the first `Pending` poll. -#[allow(dead_code)] struct SendHeaderState { /// Serialized header bytes ready to be written pending_header: Option, @@ -84,7 +82,6 @@ impl std::fmt::Debug for SendHeaderState { } impl SendHeaderState { - #[allow(dead_code)] fn new() -> Self { SendHeaderState { pending_header: None, @@ -114,7 +111,6 @@ impl SendHeaderState { /// /// A timeout, if set, is enforced *across* cancellations — the clock keeps /// ticking even when the future is dropped and re-polled. -#[allow(dead_code)] pub struct HeaderWriter { // Boxed to reduce inline size. Only used by the cancel-safe proxy task API. send_header_state: Box, @@ -127,7 +123,6 @@ impl Default for HeaderWriter { } impl HeaderWriter { - #[allow(dead_code)] pub fn new() -> Self { HeaderWriter { send_header_state: Box::new(SendHeaderState::new()), @@ -146,7 +141,6 @@ impl HeaderWriter { /// Queue serialized header bytes as a write task with an optional timeout. /// This is a non-async function that just saves the bytes. /// Call [`write_current_header_task`](Self::write_current_header_task) to actually perform the write. - #[allow(dead_code)] pub fn send_header_task( &mut self, header_bytes: Bytes, @@ -171,7 +165,6 @@ impl HeaderWriter { /// Async function that writes the current queued header task to the stream. /// This function is cancel-safe and can be called in a `tokio::select!` loop. /// Returns `Ok(bytes_written)` when complete, `Ok(0)` if no bytes to write. - #[allow(dead_code)] pub async fn write_current_header_task(&mut self, stream: &mut S) -> Result where S: AsyncWrite + Unpin + Send, @@ -304,6 +297,7 @@ impl HeaderWriter { #[cfg(test)] mod tests { use super::*; + use crate::protocols::http::v1::test_util::FlushTrackingMock; use tokio_test::io::Builder; fn init_log() { @@ -311,18 +305,26 @@ mod tests { } #[tokio::test] - async fn test_simple_header_write() { + async fn test_simple_header_write_no_flush() { init_log(); let header_data = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; - let mut mock_io = Builder::new().write(header_data).build(); + let mock_io = Builder::new().write(header_data).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); let mut header_writer = HeaderWriter::new(); header_writer.send_header_task(Bytes::from_static(header_data), false, None); - let result = header_writer.write_current_header_task(&mut mock_io).await; + let result = header_writer + .write_current_header_task(&mut flush_mock) + .await; assert!(result.is_ok()); assert_eq!(result.unwrap(), header_data.len()); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 0, + "should_flush=false should not flush" + ); } #[tokio::test] @@ -330,14 +332,22 @@ mod tests { init_log(); let header_data = b"HTTP/1.1 200 OK\r\n\r\n"; - let mut mock_io = Builder::new().write(header_data).build(); + let mock_io = Builder::new().write(header_data).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); let mut header_writer = HeaderWriter::new(); header_writer.send_header_task(Bytes::from_static(header_data), true, None); - let result = header_writer.write_current_header_task(&mut mock_io).await; + let result = header_writer + .write_current_header_task(&mut flush_mock) + .await; assert!(result.is_ok()); assert_eq!(result.unwrap(), header_data.len()); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 1, + "should_flush=true should flush exactly once" + ); } // Uses start_paused for deterministic timer-based cancellation in select! diff --git a/pingora-core/src/protocols/http/v1/mod.rs b/pingora-core/src/protocols/http/v1/mod.rs index 53acaec9..6f085a70 100644 --- a/pingora-core/src/protocols/http/v1/mod.rs +++ b/pingora-core/src/protocols/http/v1/mod.rs @@ -19,3 +19,108 @@ pub mod client; pub mod common; pub(crate) mod header; pub mod server; + +/// Test utilities shared across HTTP/1.x unit tests +#[cfg(test)] +pub(crate) mod test_util { + use std::pin::Pin; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use std::task::{Context, Poll}; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio_test::io::Mock; + + /// A wrapper around [`Mock`] that counts flush calls. + /// + /// `tokio_test::io::Mock`'s `poll_flush` always returns `Ready(Ok(()))`, + /// so we can't detect flush calls via mock alone. This wrapper counts them. + #[derive(Debug)] + pub(crate) struct FlushTrackingMock { + inner: Mock, + flush_count: Arc, + } + + impl FlushTrackingMock { + pub(crate) fn new(mock: Mock) -> (Self, Arc) { + let flush_count = Arc::new(AtomicUsize::new(0)); + ( + FlushTrackingMock { + inner: mock, + flush_count: flush_count.clone(), + }, + flush_count, + ) + } + + pub(crate) fn flush_count(counter: &Arc) -> usize { + counter.load(Ordering::Relaxed) + } + } + + impl AsyncRead for FlushTrackingMock { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_read(cx, buf) + } + } + + impl AsyncWrite for FlushTrackingMock { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let result = Pin::new(&mut this.inner).poll_flush(cx); + if let Poll::Ready(Ok(())) = &result { + this.flush_count.fetch_add(1, Ordering::Relaxed); + } + result + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_shutdown(cx) + } + } + + // Implement IO-required traits so FlushTrackingMock can be used as Box + // in HttpSession tests (server.rs). + use crate::protocols::{ + raw_connect::ProxyDigest, GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, + SocketDigest, Ssl, TimingDigest, UniqueID, UniqueIDType, + }; + + #[async_trait::async_trait] + impl Shutdown for FlushTrackingMock { + async fn shutdown(&mut self) -> () {} + } + impl UniqueID for FlushTrackingMock { + fn id(&self) -> UniqueIDType { + 0 + } + } + impl Ssl for FlushTrackingMock {} + impl GetTimingDigest for FlushTrackingMock { + fn get_timing_digest(&self) -> Vec> { + vec![] + } + } + impl GetProxyDigest for FlushTrackingMock { + fn get_proxy_digest(&self) -> Option> { + None + } + } + impl GetSocketDigest for FlushTrackingMock { + fn get_socket_digest(&self) -> Option> { + None + } + } + impl Peek for FlushTrackingMock {} +} diff --git a/pingora-core/src/protocols/http/v1/server.rs b/pingora-core/src/protocols/http/v1/server.rs index 7e648ca5..0cfdb47d 100644 --- a/pingora-core/src/protocols/http/v1/server.rs +++ b/pingora-core/src/protocols/http/v1/server.rs @@ -28,15 +28,48 @@ use pingora_http::{IntoCaseHeaderName, RequestHeader, ResponseHeader}; use pingora_timeout::timeout; use regex::bytes::Regex; use std::any::Any; +use std::collections::VecDeque; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::body::{BodyReader, BodyWriter}; use super::common::*; +use super::header::HeaderWriter; use crate::protocols::http::{body_buffer::FixedBuffer, date, HttpTask}; use crate::protocols::{Digest, SocketAddr, Stream}; use crate::utils::{BufRef, KVRef}; +/// Tracks which writer is currently processing a task. +/// +/// This enables resuming writes after cancellation. Each variant stores the +/// minimal data needed for cleanup after write completes. +#[derive(Debug)] +enum ProxyTaskWriter { + /// Currently writing a header task. + /// Stores: (header for `response_written`, end_stream flag) + WritingHeader(Box, bool), + /// Currently writing a body task (`Body` or `UpgradedBody`). + /// Stores: (end_stream flag) + WritingBody(bool), + /// Currently finishing the body (writing last chunk + flush). + FinishingBody, +} + +/// State for the cancel-safe proxy task write API. +#[derive(Default)] +struct ProxyTaskState { + /// Lazily initialized — `HeaderWriter::new()` heap-allocates. + header_writer: Option, + tasks: VecDeque, + current_writer: Option, +} + +impl ProxyTaskState { + fn header_writer(&mut self) -> &mut HeaderWriter { + self.header_writer.get_or_insert_with(HeaderWriter::new) + } +} + /// The HTTP 1.x server session pub struct HttpSession { underlying_stream: Stream, @@ -52,6 +85,8 @@ pub struct HttpSession { body_reader: BodyReader, /// A state machine to track how to write the response body body_writer: BodyWriter, + /// Cancel-safe proxy task state. + proxy_task_state: ProxyTaskState, /// An internal buffer to buf multiple body writes to reduce the underlying syscalls body_write_buf: BytesMut, /// Track how many application (not on the wire) body bytes already sent @@ -98,6 +133,9 @@ pub struct HttpSession { /// close is tolerated and `read_body_or_idle` stays pending so the proxy can /// finish delivering the upstream response (RFC 9112 Section 9.6). abort_on_close: bool, + /// Whether the cancel-safe proxy task API is enabled for this session. + /// Defaults to false. Can be enabled via [`set_proxy_tasks_enabled`](Self::set_proxy_tasks_enabled). + proxy_tasks_enabled: bool, } impl HttpSession { @@ -120,6 +158,7 @@ impl HttpSession { preread_body: None, body_reader: BodyReader::new(false), body_writer: BodyWriter::new(), + proxy_task_state: ProxyTaskState::default(), body_write_buf: BytesMut::new(), keepalive_timeout: KeepaliveStatus::Off, update_resp_headers: true, @@ -141,6 +180,7 @@ impl HttpSession { connection_user_context: None, half_closed: false, abort_on_close: true, + proxy_tasks_enabled: false, } } @@ -510,101 +550,12 @@ impl HttpSession { /// Write the response header to the client. /// This function can be called more than once to send 1xx informational headers excluding 101. pub async fn write_response_header(&mut self, mut header: Box) -> Result<()> { - if header.status.is_informational() && self.ignore_info_resp(header.status.into()) { - debug!("ignoring informational headers"); + // Prepare header (handle upgrades, set headers, initialize body writer, serialize to bytes) + let Some((write_buf, flush)) = self.prepare_response_header(&mut header)? else { + // Header already sent or should be ignored return Ok(()); - } - - if let Some(resp) = self.response_written.as_ref() { - if !resp.status.is_informational() || self.upgraded { - warn!("Respond header is already sent, cannot send again"); - return Ok(()); - } - } - - // if body unfinished, or request header was not finished reading - if self.close_on_response_before_downstream_finish - && (self.request_header.is_none() || !self.is_body_done()) - { - debug!("set connection close before downstream finish"); - self.set_keepalive(None); - } - - // no need to add these headers to 1xx responses - if !header.status.is_informational() && self.update_resp_headers { - /* update headers */ - header.insert_header(header::DATE, date::get_cached_date())?; - - // TODO: make these lazy static - let connection_value = if self.will_keepalive() { - "keep-alive" - } else { - "close" - }; - header.insert_header(header::CONNECTION, connection_value)?; - } - - if header.status == 101 { - // make sure the connection is closed at the end when 101/upgrade is used - self.set_keepalive(None); - } - - // Allow informational header (excluding 101) to pass through without affecting the state - // of the request - if header.status == 101 || !header.status.is_informational() { - // reset request body to done for incomplete upgrade handshakes - if let Some(upgrade_ok) = self.is_upgrade(&header) { - if upgrade_ok { - debug!("ok upgrade handshake"); - // For ws we use HTTP1_0 do_read_body_until_closed - // - // On ws close the initiator sends a close frame and - // then waits for a response from the peer, once it receives - // a response it closes the conn. After receiving a - // control frame indicating the connection should be closed, - // a peer discards any further data received. - // https://www.rfc-editor.org/rfc/rfc6455#section-1.4 - self.upgraded = true; - // Now that the upgrade was successful, we need to change - // how we interpret the rest of the body as pass-through. - if self.body_reader.need_init() { - self.init_body_reader(); - } else { - // already initialized - // immediately start reading the rest of the body as upgraded - // (in practice most upgraded requests shouldn't have any body) - // - // TODO: https://datatracker.ietf.org/doc/html/rfc9110#name-upgrade - // the most spec-compliant behavior is to switch interpretation - // after sending the former body, - // we immediately switch interpretation to match nginx - self.body_reader.convert_to_close_delimited(); - } - } else { - // this was a request that requested Upgrade, - // but upstream did not comply - debug!("bad upgrade handshake!"); - // continue to read body as-is, this is now just a regular request - } - } - self.init_body_writer(&header); - } - - // Defense-in-depth: if response body is close-delimited, mark session - // as un-reusable - if self.body_writer.is_close_delimited() { - self.set_keepalive(None); - } - - // Don't have to flush response with content length because it is less - // likely to be real time communication. So do flush when - // 1.1xx response: client needs to see it before the rest of response - // 2.No content length: the response could be generated in real time - let flush = header.status.is_informational() - || header.headers.get(header::CONTENT_LENGTH).is_none(); + }; - let mut write_buf = BytesMut::with_capacity(INIT_HEADER_BUF_SIZE); - http_resp_header_to_buf(&header, &mut write_buf).unwrap(); match self.underlying_stream.write_all(&write_buf).await { Ok(()) => { // flush the stream if 1xx header or there is no response body @@ -759,6 +710,117 @@ impl HttpSession { } } + /// Prepare response header for writing: handle upgrades, set headers, initialize body writer. + /// This contains all the synchronous logic that should happen before writing the header. + /// Returns Ok(Some((bytes, should_flush))) if the header should be written, Ok(None) if should skip. + fn prepare_response_header( + &mut self, + header: &mut ResponseHeader, + ) -> Result> { + // Check if we should ignore informational responses + if header.status.is_informational() && self.ignore_info_resp(header.status.into()) { + debug!("ignoring informational headers"); + return Ok(None); + } + + // Check if we already sent a response header + if let Some(ref resp) = self.response_written { + if !resp.status.is_informational() || self.upgraded { + warn!("Respond header is already sent, cannot send again"); + return Ok(None); + } + } + + // if body unfinished, or request header was not finished reading + if self.close_on_response_before_downstream_finish + && (self.request_header.is_none() || !self.is_body_done()) + { + debug!("set connection close before downstream finish"); + self.set_keepalive(None); + } + + // no need to add these headers to 1xx responses + if !header.status.is_informational() && self.update_resp_headers { + /* update headers */ + header.insert_header(header::DATE, date::get_cached_date())?; + + // TODO: make these lazy static + let connection_value = if self.will_keepalive() { + "keep-alive" + } else { + "close" + }; + header.insert_header(header::CONNECTION, connection_value)?; + } + + if header.status == 101 { + // make sure the connection is closed at the end when 101/upgrade is used + self.set_keepalive(None); + } + + // Allow informational header (excluding 101) to pass through without affecting the state + // of the request + if header.status == 101 || !header.status.is_informational() { + // reset request body to done for incomplete upgrade handshakes + if let Some(upgrade_ok) = self.is_upgrade(header) { + if upgrade_ok { + debug!("ok upgrade handshake"); + // For ws we use HTTP1_0 do_read_body_until_closed + // + // On ws close the initiator sends a close frame and + // then waits for a response from the peer, once it receives + // a response it closes the conn. After receiving a + // control frame indicating the connection should be closed, + // a peer discards any further data received. + // https://www.rfc-editor.org/rfc/rfc6455#section-1.4 + self.upgraded = true; + // Now that the upgrade was successful, we need to change + // how we interpret the rest of the body as pass-through. + if self.body_reader.need_init() { + self.init_body_reader(); + } else { + // already initialized + // immediately start reading the rest of the body as upgraded + // (in practice most upgraded requests shouldn't have any body) + // + // TODO: https://datatracker.ietf.org/doc/html/rfc9110#name-upgrade + // the most spec-compliant behavior is to switch interpretation + // after sending the former body, + // we immediately switch interpretation to match nginx + self.body_reader.convert_to_close_delimited(); + } + } else { + // this was a request that requested Upgrade, + // but upstream did not comply + debug!("bad upgrade handshake!"); + // continue to read body as-is, this is now just a regular request + } + } + self.init_body_writer(header); + } + + // Defense-in-depth: if response body is close-delimited, mark session + // as un-reusable + if self.body_writer.is_close_delimited() { + self.set_keepalive(None); + } + + // Serialize header to bytes + let mut write_buf = BytesMut::with_capacity(INIT_HEADER_BUF_SIZE); + http_resp_header_to_buf(header, &mut write_buf) + .map_err(|_| Error::explain(WriteError, "serializing response header"))?; + + // Determine if we should flush + // Don't have to flush response with content length because it is less + // likely to be real time communication. So do flush when + // 1. 1xx response: client needs to see it before the rest of response + // 2. No content length: the response could be generated in real time + let should_flush = header.status.is_informational() + || header.headers.get(header::CONTENT_LENGTH).is_none(); + + Ok(Some((write_buf.freeze(), should_flush))) + } + fn init_body_writer(&mut self, header: &ResponseHeader) { use http::StatusCode; /* the following responses don't have body 204, 304, and HEAD */ @@ -1320,6 +1382,152 @@ impl HttpSession { Ok(end_stream || self.body_writer.finished()) } + /// Queue a proxy task for cancel-safe writing with the current write_timeout. + /// The task will be written when `write_proxy_tasks()` is called. + /// + /// A write canceled mid-operation can be resumed via `write_proxy_tasks()`. + pub fn send_proxy_task(&mut self, task: HttpTask) { + self.proxy_task_state.tasks.push_back(task); + } + + /// Check if there are pending proxy tasks queued for writing. + pub fn has_pending_proxy_tasks(&self) -> bool { + !self.proxy_task_state.tasks.is_empty() + } + + /// Write all queued proxy tasks (response `HttpTask`s from `send_proxy_task`) + /// in a cancel-safe manner. + /// + /// If cancelled mid-write, the next call will resume the in-progress write. + /// + /// Returns `Ok(true)` if this was the end of the response stream. + // Leverages the cancel-safe `HeaderWriter` and `BodyWriter` primitives. + // TODO: we can do the same for the non-cancel-safe APIs. + pub async fn write_proxy_tasks(&mut self) -> Result { + let mut end_stream = false; + + // TODO: buffer body data like response_duplex_vec + loop { + // - Resume any in-progress write + if let Some(ref writer_state) = self.proxy_task_state.current_writer { + match writer_state { + ProxyTaskWriter::WritingHeader(_, _) => { + let _bytes_written = self + .proxy_task_state + .header_writer() + .write_current_header_task(&mut self.underlying_stream) + .await + .map_err(|e| e.into_down())?; + } + ProxyTaskWriter::WritingBody(_) => { + let written = self + .body_writer + .write_current_body_task(&mut self.underlying_stream) + .await + .map_err(|e| e.into_down())?; + if let Some(n) = written { + self.body_bytes_sent += n; + } + } + ProxyTaskWriter::FinishingBody => { + self.body_writer + .write_current_finish_task(&mut self.underlying_stream) + .await + .map_err(|e| e.into_down())?; + } + } + + match self + .proxy_task_state + .current_writer + .take() + .expect("writer state present") + { + ProxyTaskWriter::WritingHeader(header, end) => { + self.response_written = Some(header); + end_stream = end; + } + ProxyTaskWriter::WritingBody(end) => { + end_stream = end; + } + ProxyTaskWriter::FinishingBody => { + end_stream = true; + self.maybe_force_close_body_reader(); + break; // fine to break after finish, no tasks should be queued after + } + } + continue; + } + + // - Send tasks, set state. + // Pop next task + let Some(task) = self.proxy_task_state.tasks.pop_front() else { + if end_stream { + self.body_writer.send_finish_task(); + self.proxy_task_state.current_writer = Some(ProxyTaskWriter::FinishingBody); + continue; + } + break; + }; + + match task { + HttpTask::Header(mut header, end) => { + let Some((write_buf, should_flush)) = + self.prepare_response_header(&mut header)? + else { + end_stream = end; + continue; + }; + // header only responses will want to flush + let flush = should_flush || self.body_writer.finished(); + self.proxy_task_state + .header_writer() + .send_header_task(write_buf, flush, None); + self.proxy_task_state.current_writer = + Some(ProxyTaskWriter::WritingHeader(header, end)); + } + HttpTask::Body(ref data, end) => { + if self.upgraded { + panic!("Unexpected Body task received on upgraded downstream session"); + } + if let Some(d) = data.as_ref() { + if !d.is_empty() { + let body_timeout = self.write_timeout(d.len()); + self.body_writer.send_body_task(d.clone(), body_timeout); + self.proxy_task_state.current_writer = + Some(ProxyTaskWriter::WritingBody(end)); + continue; + } + } + end_stream = end; + } + HttpTask::UpgradedBody(ref data, end) => { + if !self.upgraded { + panic!("Unexpected UpgradedBody task received on un-upgraded downstream session"); + } + if let Some(d) = data.as_ref() { + if !d.is_empty() { + let body_timeout = self.write_timeout(d.len()); + self.body_writer.send_body_task(d.clone(), body_timeout); + self.proxy_task_state.current_writer = + Some(ProxyTaskWriter::WritingBody(end)); + continue; + } + } + end_stream = end; + } + HttpTask::Trailer(_) | HttpTask::Done => { + end_stream = true; + } + HttpTask::Failed(e) => { + return Err(e); + } + } + } + + Ok(end_stream || self.body_writer.finished()) + } + /// Get the reference of the [Stream] that this HTTP session is operating upon. pub fn stream(&self) -> &Stream { &self.underlying_stream @@ -2852,25 +3060,30 @@ mod test_sync { } #[cfg(test)] -mod test_timeouts { +mod test_proxy_tasks { use super::*; + use http::StatusCode; use std::future::IntoFuture; use tokio_test::io::{Builder, Mock}; - /// An upper limit for any read within any test to prevent tests from hanging forever if - /// an internal read call never returns, etc. + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + // An upper limit for any read within any test to prevent tests from hanging forever if + // an internal read call never returns, etc. const TEST_MAX_WAIT_FOR_READ: Duration = Duration::from_secs(3); - /// The duration of 600 seconds is chosen to be "effectively forever" for the purpose of testing + // The duration of 600 seconds is chosen to be "effectively forever" for the purpose of testing const TEST_FOREVER_DURATION: Duration = Duration::from_secs(600); - /// The read_timeout to use, when we want to test that a read operation times out + // The read_timeout to use, when we want to test that a read operation times out const TEST_READ_TIMEOUT: Duration = Duration::from_secs(1); #[derive(Debug)] struct ReadBlockedForeverError; - /// Returns a client stream that will "never" send any bytes / return from a read operation + // Returns a client stream that will "never" send any bytes / return from a read operation fn mocked_blocking_headers_forever_stream() -> Box { Box::new(Builder::new().wait(TEST_FOREVER_DURATION).build()) } @@ -2887,8 +3100,8 @@ mod test_timeouts { ) } - /// Helper function to test a read operation with a tokio timeout - /// to prevent tests from hanging forever in case of a bug + // Helper function to test a read operation with a tokio timeout + // to prevent tests from hanging forever in case of a bug async fn test_read_with_tokio_timeout( read_future: F, ) -> Result>, ReadBlockedForeverError> @@ -2932,6 +3145,352 @@ mod test_timeouts { assert!(res.is_ok()); assert_eq!(res.unwrap().unwrap_err().etype(), &ReadTimedout); } + + #[tokio::test] + async fn test_send_proxy_task_and_write() { + init_log(); + + // We need to know exact bytes that will be written + // "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello" + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let expected_body = b"hello"; + + let mock_io = Builder::new() + .write(expected_header) + .write(expected_body) + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; // Disable automatic headers + + // Queue header task + let mut header = ResponseHeader::build(StatusCode::OK, Some(5)).unwrap(); + header.insert_header("Content-Length", "5").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + + // Queue body task + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // Write all tasks + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + assert!(end_stream); + } + + #[tokio::test] + async fn test_proxy_task_with_timeout() { + init_log(); + + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let expected_body = b"hello"; + + let mock_io = Builder::new() + .write(expected_header) + .write(expected_body) + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + http_stream.write_timeout = Some(Duration::from_secs(1)); // Set write timeout + + // Queue tasks + let mut header = ResponseHeader::build(StatusCode::OK, Some(5)).unwrap(); + header.insert_header("Content-Length", "5").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // Verify initial state + assert_eq!( + http_stream.body_bytes_sent(), + 0, + "Should start with 0 bytes sent" + ); + + // Write all tasks with timeout + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + assert!(end_stream); + + // Verify body bytes were counted correctly (not double counted) + assert_eq!( + http_stream.body_bytes_sent(), + 5, + "Should count exactly 5 bytes (application level), not double counted" + ); + } + + // Test that write_proxy_tasks is cancel-safe: if the future is dropped mid-execution, + // unwritten tasks should remain in the queue. + #[tokio::test] + async fn test_proxy_task_cancel_safety() { + init_log(); + + let expected_header = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"; + // First chunk: "5\r\nhello\r\n" + let expected_chunk1 = b"5\r\nhello\r\n"; + + // Create a mock IO that will write the header and first chunk, + // but will block indefinitely on the second chunk + let mock_io = Builder::new() + .write(expected_header) + .write(expected_chunk1) + .wait(Duration::from_secs(999)) // This will cause timeout + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + http_stream.write_timeout = Some(Duration::from_millis(100)); + + // Queue 3 tasks: header + 2 body chunks + let mut header = ResponseHeader::build(StatusCode::OK, None).unwrap(); + header + .insert_header("Transfer-Encoding", "chunked") + .unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("world")), true)); + + // Verify we have 3 tasks queued + assert_eq!(http_stream.proxy_task_state.tasks.len(), 3); + + // Try to write all tasks - this should timeout while writing the second body chunk + let result = http_stream.write_proxy_tasks().await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().etype(), &WriteTimedout); + + // With the refactored cancel-safe design: + // - First task (header) was written successfully and removed from queue + // - Second task (first body "hello") was removed and sent to BodyWriter, write succeeded, state cleared + // - Third task (second body "world") was removed and sent to BodyWriter, timed out mid-write + // - The in-progress write state is tracked in current_writer, NOT in the queue + assert_eq!( + http_stream.proxy_task_state.tasks.len(), + 0, + "Queue should be empty - tasks are owned by writers once sent" + ); + + // The task being written should be tracked in current_writer + assert!( + matches!( + http_stream.proxy_task_state.current_writer, + Some(ProxyTaskWriter::WritingBody(_)) + ), + "Should be mid-write of body task - writer owns the 'world' task state" + ); + + // Verify body_bytes_sent only counts the successfully written "hello" (5 bytes) + // not the timed-out "world" + assert_eq!( + http_stream.body_bytes_sent(), + 5, + "Should only count the 5 bytes from 'hello', not the incomplete 'world' write" + ); + + // On next call to write_proxy_tasks(), Step 1 will resume the "world" write + } + + use crate::protocols::http::v1::test_util::FlushTrackingMock; + + // Test that write_continue_response can be called before write_proxy_tasks + // and both work correctly together. + #[tokio::test] + async fn test_continue_response_before_proxy_tasks() { + init_log(); + + // Expected bytes written: + // 1. 100 Continue response + // 2. 200 OK response header + // 3. Body data + let expected_continue = b"HTTP/1.1 100 Continue\r\n\r\n"; + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let expected_body = b"hello"; + + let mock_io = Builder::new() + .write(expected_continue) + .write(expected_header) + .write(expected_body) + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; // Disable automatic headers + + // First, write the 100 Continue response + http_stream.write_continue_response().await.unwrap(); + + // Verify that 100 Continue was recorded + assert!( + http_stream.response_written().is_some(), + "100 Continue should be recorded in response_written" + ); + assert_eq!( + http_stream.response_written().unwrap().status, + StatusCode::CONTINUE, + "Should have recorded 100 Continue" + ); + + // Now queue the actual response using proxy tasks + let mut header = ResponseHeader::build(StatusCode::OK, Some(5)).unwrap(); + header.insert_header("Content-Length", "5").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // Write all proxy tasks + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + assert!(end_stream, "Should indicate end of stream"); + + // Verify final response is 200 OK, not 100 Continue + assert_eq!( + http_stream.response_written().unwrap().status, + StatusCode::OK, + "Final response should be 200 OK, overwriting 100 Continue" + ); + } + + #[tokio::test] + async fn test_head_response_with_content_length_flushes() { + init_log(); + + // HEAD request line + headers + let request = b"HEAD / HTTP/1.1\r\nHost: example.com\r\n\r\n"; + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n"; + + let mock_io = Builder::new().read(request).write(expected_header).build(); + let (flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + let mut http_stream = HttpSession::new(Box::new(flush_mock)); + http_stream.update_resp_headers = false; + + // Read the HEAD request + http_stream.read_request().await.unwrap(); + assert_eq!(http_stream.get_method(), Some(&Method::HEAD)); + + // Queue header with Content-Length (body will be empty for HEAD) + let mut header = ResponseHeader::build(StatusCode::OK, Some(2)).unwrap(); + header.insert_header("Content-Length", "100").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), true)); + + let flush_before = FlushTrackingMock::flush_count(&flush_count); + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + let flush_after = FlushTrackingMock::flush_count(&flush_count); + + assert!(end_stream, "HEAD response should be end of stream"); + assert!( + flush_after > flush_before, + "Should flush after writing HEAD response header with Content-Length \ + (body_writer.finished() is true). Got flush_before={flush_before}, \ + flush_after={flush_after}" + ); + } + + #[tokio::test] + async fn test_204_response_with_content_length_flushes() { + init_log(); + + let request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"; + let expected_header = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n"; + + let mock_io = Builder::new().read(request).write(expected_header).build(); + let (flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + let mut http_stream = HttpSession::new(Box::new(flush_mock)); + http_stream.update_resp_headers = false; + + http_stream.read_request().await.unwrap(); + + let mut header = ResponseHeader::build(StatusCode::NO_CONTENT, Some(2)).unwrap(); + header.insert_header("Content-Length", "0").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), true)); + + let flush_before = FlushTrackingMock::flush_count(&flush_count); + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + let flush_after = FlushTrackingMock::flush_count(&flush_count); + + assert!(end_stream, "204 response should be end of stream"); + assert!( + flush_after > flush_before, + "Should flush after writing 204 response header with Content-Length \ + (body_writer.finished() is true). Got flush_before={flush_before}, \ + flush_after={flush_after}" + ); + } + + #[tokio::test] + #[should_panic( + expected = "Unexpected UpgradedBody task received on un-upgraded downstream session" + )] + async fn test_upgraded_body_on_non_upgraded_session_panics() { + init_log(); + + let request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"; + let expected_header = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"; + // UpgradedBody on a non-upgraded session should panic before writing, + // but if the bug exists, BodyWriter would encode it as a chunk: + let expected_chunk = b"5\r\nhello\r\n"; + let expected_finish = b"0\r\n\r\n"; + + let mock_io = Builder::new() + .read(request) + .write(expected_header) + // If the panic check is missing, the body gets written as a chunk + .write(expected_chunk) + .write(expected_finish) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + + http_stream.read_request().await.unwrap(); + assert!( + !http_stream.was_upgraded(), + "Session should NOT be upgraded" + ); + + // Queue a normal header + let mut header = ResponseHeader::build(StatusCode::OK, Some(2)).unwrap(); + header + .insert_header("Transfer-Encoding", "chunked") + .unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + + // Queue an UpgradedBody task on a non-upgraded session — should panic + http_stream.send_proxy_task(HttpTask::UpgradedBody(Some(Bytes::from("hello")), true)); + + // This should panic before/during the body write + let _ = http_stream.write_proxy_tasks().await; + } + + #[tokio::test] + #[should_panic(expected = "Unexpected Body task received on upgraded downstream session")] + async fn test_body_on_upgraded_session_panics() { + init_log(); + + // Upgrade request + let request = + b"GET / HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; + // 101 Switching Protocols response + let expected_header = + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; + // If the panic check is missing, Body data would be written raw (close-delimited) + let expected_body = b"hello"; + + let mock_io = Builder::new() + .read(request) + .write(expected_header) + .write(expected_body) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + + http_stream.read_request().await.unwrap(); + + // Queue 101 header to complete the upgrade + let mut header = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, Some(3)).unwrap(); + header.insert_header("Upgrade", "websocket").unwrap(); + header.insert_header("Connection", "Upgrade").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + + // Queue a regular Body task on what will be an upgraded session — should panic + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // This should panic (after writing the header, session becomes upgraded, + // then the Body task should be rejected) + let _ = http_stream.write_proxy_tasks().await; + } } #[cfg(test)] diff --git a/pingora-proxy/src/lib.rs b/pingora-proxy/src/lib.rs index 52a89cbd..3faad4e4 100644 --- a/pingora-proxy/src/lib.rs +++ b/pingora-proxy/src/lib.rs @@ -587,57 +587,106 @@ impl Session { .await } - pub async fn write_response_tasks(&mut self, mut tasks: Vec) -> Result { - let mut seen_upgraded = self.was_upgraded(); - for task in tasks.iter_mut() { - match task { - HttpTask::Header(resp, end) => { - self.downstream_modules_ctx - .response_header_filter(resp, *end) - .await?; - } - HttpTask::Body(data, end) => { - self.downstream_modules_ctx - .response_body_filter(data, *end)?; - } - HttpTask::UpgradedBody(data, end) => { - seen_upgraded = true; - self.downstream_modules_ctx - .response_body_filter(data, *end)?; - } - HttpTask::Trailer(trailers) => { - if let Some(buf) = self - .downstream_modules_ctx - .response_trailer_filter(trailers)? - { - // Write the trailers into the body if the filter - // returns a buffer. - // - // Note, this will not work if end of stream has already - // been seen or we've written content-length bytes. - // (Trailers should never come after upgraded body) - *task = HttpTask::Body(Some(buf), true); - } - } - HttpTask::Done => { - // `Done` can be sent in certain response paths to mark end - // of response if not already done via trailers or body with - // end flag set. - // If the filter returns body bytes on Done, - // write them into the response. + // Run downstream module response filters on a single task, updating + // `seen_upgraded` to track whether an upgrade has been seen. Used by both + // `send_downstream_proxy_task` and `write_response_tasks`. + async fn downstream_response_task_filter( + &mut self, + task: &mut HttpTask, + seen_upgraded: &mut bool, + ) -> Result<()> { + match task { + HttpTask::Header(resp, end) => { + self.downstream_modules_ctx + .response_header_filter(resp, *end) + .await?; + } + HttpTask::Body(data, end) => { + self.downstream_modules_ctx + .response_body_filter(data, *end)?; + } + HttpTask::UpgradedBody(data, end) => { + *seen_upgraded = true; + self.downstream_modules_ctx + .response_body_filter(data, *end)?; + } + HttpTask::Trailer(trailers) => { + if let Some(buf) = self + .downstream_modules_ctx + .response_trailer_filter(trailers)? + { + // Write the trailers into the body if the filter + // returns a buffer. // // Note, this will not work if end of stream has already // been seen or we've written content-length bytes. - if let Some(buf) = self.downstream_modules_ctx.response_done_filter()? { - if seen_upgraded { - *task = HttpTask::UpgradedBody(Some(buf), true); - } else { - *task = HttpTask::Body(Some(buf), true); - } + // (Trailers should never come after upgraded body) + *task = HttpTask::Body(Some(buf), true); + } + } + HttpTask::Done => { + // `Done` can be sent in certain response paths to mark end + // of response if not already done via trailers or body with + // end flag set. + // If the filter returns body bytes on Done, + // write them into the response. + // + // Note, this will not work if end of stream has already + // been seen or we've written content-length bytes. + if let Some(buf) = self.downstream_modules_ctx.response_done_filter()? { + if *seen_upgraded { + *task = HttpTask::UpgradedBody(Some(buf), true); + } else { + *task = HttpTask::Body(Some(buf), true); } } - _ => { /* Failed */ } } + _ => { /* Failed */ } + } + Ok(()) + } + + /// Queue a downstream proxy task for cancel-safe writing after running + /// downstream module filters. This allows decoupling cache writes from + /// downstream writes. + /// + /// Only works with sessions that support the proxy task API (currently H1). + /// + /// # Panics + /// Panics if the session doesn't support the proxy task API. + /// Use `write_response_tasks()` for sessions that don't support the proxy task API. + pub async fn send_downstream_proxy_task(&mut self, mut task: HttpTask) -> Result<()> { + let mut seen_upgraded = self.was_upgraded(); + self.downstream_response_task_filter(&mut task, &mut seen_upgraded) + .await?; + self.downstream_session.send_downstream_proxy_task(task); + Ok(()) + } + + /// Check if there are pending downstream tasks queued for writing. + /// Used for backpressure - don't queue more cache tasks if we have pending writes. + /// Returns false for sessions that don't support the proxy task API. + pub fn has_pending_downstream_tasks(&self) -> bool { + self.downstream_session.supports_proxy_task_api() + && self.downstream_session.has_pending_downstream_proxy_tasks() + } + + /// Write all queued downstream proxy tasks. This is cancel-safe and can be called + /// in a select! loop while waiting for upstream tasks. + /// For sessions that don't support the proxy task API, this is a no-op. + pub async fn write_downstream_proxy_tasks(&mut self) -> Result { + if self.downstream_session.supports_proxy_task_api() { + self.downstream_session.write_downstream_proxy_tasks().await + } else { + Ok(false) + } + } + + pub async fn write_response_tasks(&mut self, mut tasks: Vec) -> Result { + let mut seen_upgraded = self.was_upgraded(); + for task in tasks.iter_mut() { + self.downstream_response_task_filter(task, &mut seen_upgraded) + .await?; } self.downstream_session.response_duplex_vec(tasks).await } From 8683056e565a7b398083b09b055e888d5b4fbddf Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Sun, 1 Mar 2026 17:43:15 -0800 Subject: [PATCH 16/52] Use proxy task API for cache-served proxy_h1 downstream writes Using the proxy task API allows polling for the upstream rx task at the same time, so that upstream cache writes can continue even while serving downstream. proxy_h2 and h2 downstream (as well as custom) is a todo. --- .bleep | 2 +- pingora-proxy/src/proxy_common.rs | 91 +++++- pingora-proxy/src/proxy_h1.rs | 276 ++++++++++++----- pingora-proxy/src/proxy_h2.rs | 2 + pingora-proxy/tests/test_upstream.rs | 289 ++++++++++++++++++ .../tests/utils/conf/origin/conf/nginx.conf | 1 - pingora-proxy/tests/utils/server_utils.rs | 16 + 7 files changed, 592 insertions(+), 85 deletions(-) diff --git a/.bleep b/.bleep index 5a60cfc8..6cca1e78 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -b8823a8f0713f33ec7f83a5c2df8d5491c8a5613 \ No newline at end of file +033d34cfe2e46f59be14e956033f0b55af3daa45 \ No newline at end of file diff --git a/pingora-proxy/src/proxy_common.rs b/pingora-proxy/src/proxy_common.rs index e1d36f69..6c40760c 100644 --- a/pingora-proxy/src/proxy_common.rs +++ b/pingora-proxy/src/proxy_common.rs @@ -1,3 +1,17 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /// Possible downstream states during request multiplexing #[derive(Debug, Clone, Copy)] pub(crate) enum DownstreamStateMachine { @@ -36,19 +50,28 @@ impl DownstreamStateMachine { matches!(self, Self::Errored) } - /// Move the state machine to Finished state if `set` is true + /// Move the state machine to Finished state if `set` is true. + /// + /// No-op when the current state is [`Errored`](Self::Errored) — once errored the + /// downstream connection must not be reused, and late upstream chunks arriving + /// via `rx.recv()` must not overwrite that decision. pub fn maybe_finished(&mut self, set: bool) { - if set { + if set && !self.is_errored() { *self = Self::ReadingFinished } } - /// Reset if we should continue reading from the downstream again. - /// Only used with upgraded connections when body mode changes. + /// Reset to [`Reading`](Self::Reading) for upgraded connections when body mode changes. + /// + /// No-op when the current state is [`Errored`](Self::Errored). pub fn reset(&mut self) { - *self = Self::Reading; + if !self.is_errored() { + *self = Self::Reading; + } } + /// Transition to [`Errored`](Self::Errored). This is a terminal state: once entered, + /// no other state transition is permitted and the connection must not be reused. pub fn to_errored(&mut self) { *self = Self::Errored } @@ -97,3 +120,61 @@ impl ResponseStateMachine { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn normal_lifecycle() { + let mut ds = DownstreamStateMachine::new(false); + assert!(ds.is_reading()); + assert!(ds.can_poll()); + assert!(!ds.is_errored()); + + ds.maybe_finished(true); + assert!(!ds.is_reading()); + assert!(ds.is_done()); + assert!(ds.can_poll()); // ReadingFinished still allows polling (for idle) + assert!(!ds.is_errored()); + } + + #[test] + fn errored_is_terminal() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + assert!(ds.is_done()); + } + + /// `maybe_finished(false)` is always a no-op regardless of state. + #[test] + fn maybe_finished_false_is_noop() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + ds.maybe_finished(false); // must not panic + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + } + + /// `maybe_finished(true)` on `Errored` is a no-op — `Errored` is terminal. + #[test] + fn maybe_finished_true_noop_on_errored() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + ds.maybe_finished(true); // must not overwrite Errored + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + } + + /// `reset()` on `Errored` is a no-op — `Errored` is terminal. + #[test] + fn reset_noop_on_errored() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + ds.reset(); // must not overwrite Errored + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + } +} diff --git a/pingora-proxy/src/proxy_h1.rs b/pingora-proxy/src/proxy_h1.rs index 9f498aa0..dbf6e5ca 100644 --- a/pingora-proxy/src/proxy_h1.rs +++ b/pingora-proxy/src/proxy_h1.rs @@ -267,6 +267,81 @@ where Ok(()) } + #[allow(clippy::too_many_arguments)] + async fn process_upstream_tasks( + &self, + session: &mut Session, + ctx: &mut SV::CTX, + initial_task: HttpTask, + rx: &mut mpsc::Receiver, + serve_from_cache: &mut ServeFromCache, + range_body_filter: &mut proxy_cache::range_filter::RangeBodyFilter, + response_state: &mut ResponseStateMachine, + ) -> Result> + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + if serve_from_cache.should_discard_upstream() { + // just drain, do we need to do anything else? + return Ok(None); + } + + // Batch: pull as many tasks as we can from rx + let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + tasks.push(initial_task); + // tokio::task::unconstrained because now_or_never may yield None when the future is ready + while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { + debug!("upstream event now: {:?}", maybe_task); + if let Some(t) = maybe_task { + tasks.push(t); + } else { + break; // upstream closed + } + } + + /* run filters before sending to downstream */ + let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + for mut t in tasks { + if self.revalidate_or_stale(session, &mut t, ctx).await { + serve_from_cache.enable(); + response_state.enable_cached_response(); + // skip downstream filtering entirely as the 304 will not be sent + break; + } + #[cfg(feature = "adjust_upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } + session.upstream_compression.response_filter(&mut t); + let task = self + .h1_response_filter(session, t, ctx, serve_from_cache, range_body_filter, false) + .await?; + if serve_from_cache.is_miss_header() { + response_state.enable_cached_response(); + } + // check error and abort + // otherwise the error is surfaced via write_response_tasks() + if !serve_from_cache.should_send_to_downstream() { + if let HttpTask::Failed(e) = task { + return Err(e); + } + } + filtered_tasks.push(task); + } + + if !serve_from_cache.should_send_to_downstream() { + // TODO: need to derive response_done from filtered_tasks in case downstream failed already + return Ok(None); + } + + let response_done = session.write_response_tasks(filtered_tasks).await?; + + Ok(Some(response_done)) + } + // todo use this function to replace bidirection_1to2() // returns whether this server (downstream) session can be reused async fn proxy_handle_downstream( @@ -329,6 +404,8 @@ where let mut serve_from_cache = proxy_cache::ServeFromCache::new(); let mut range_body_filter = proxy_cache::range_filter::RangeBodyFilter::new(); + let mut next_upstream_task: Option = None; + /* duplex mode without caching * Read body from downstream while reading response from upstream * If response is done, only read body from downstream @@ -424,74 +501,56 @@ where // If tx is closed, the upstream has already finished its job. downstream_state.maybe_finished(tx.is_closed()); debug!("waiting for permit {send_permit:?}, upstream closed {}", tx.is_closed()); - /* No permit, wait on more capacity to avoid starving. + /* No permit, wait on more capacity to avoid starving. * Otherwise this select only blocks on rx, which might send no data * before the entire body is uploaded. * once more capacity arrives we just loop back */ }, - task = rx.recv(), if !response_state.upstream_done() => { - debug!("upstream event: {:?}", task); + // Handle buffered upstream task from previous iteration + task = async { next_upstream_task.take() }, if next_upstream_task.is_some() => { + debug!("buffered upstream event: {:?}", task); if let Some(t) = task { - if serve_from_cache.should_discard_upstream() { - // just drain, do we need to do anything else? - continue; - } - // pull as many tasks as we can - let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - tasks.push(t); - // tokio::task::unconstrained because now_or_never may yield None when the future is ready - while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { - debug!("upstream event now: {:?}", maybe_task); - if let Some(t) = maybe_task { - tasks.push(t); - } else { - break; // upstream closed - } - } - - /* run filters before sending to downstream */ - let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - for mut t in tasks { - if self.revalidate_or_stale(session, &mut t, ctx).await { - serve_from_cache.enable(); - response_state.enable_cached_response(); - // skip downstream filtering entirely as the 304 will not be sent - break; - } - #[cfg(feature = "adjust_upstream_modules")] - if let HttpTask::Header(header, end_of_stream) = &t { - self.inner - .adjust_upstream_modules(session, header, *end_of_stream, ctx) - .await?; - } - session.upstream_compression.response_filter(&mut t); - let task = self.h1_response_filter(session, t, ctx, - &mut serve_from_cache, - &mut range_body_filter, false).await?; - if serve_from_cache.is_miss_header() { - response_state.enable_cached_response(); - } - // check error and abort - // otherwise the error is surfaced via write_response_tasks() - if !serve_from_cache.should_send_to_downstream() { - if let HttpTask::Failed(e) = task { - return Err(e); - } - } - filtered_tasks.push(task); - } - - if !serve_from_cache.should_send_to_downstream() { - // TODO: need to derive response_done from filtered_tasks in case downstream failed already + let Some(response_done) = self.process_upstream_tasks( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache continue; - } + }; + response_state.maybe_set_upstream_done(response_done); + // unsuccessful upgrade response may force the request done + downstream_state.maybe_finished(session.is_body_done()); + } else { + debug!("empty upstream event"); + response_state.maybe_set_upstream_done(true); + } + }, - // set to downstream + task = rx.recv(), if !response_state.upstream_done() && next_upstream_task.is_none() => { + debug!("upstream event: {:?}", task); + if let Some(t) = task { let upgraded = session.was_upgraded(); - let response_done = session.write_response_tasks(filtered_tasks).await?; + let Some(response_done) = self.process_upstream_tasks( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache + continue; + }; if !upgraded && session.was_upgraded() && downstream_state.can_poll() { + // TODO: write can happen async now // just upgraded, the downstream state should be reset to continue to // poll body trace!("reset downstream state on upgrade"); @@ -508,35 +567,96 @@ where }, task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), - if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { + if !response_state.cached_done() + && !downstream_state.is_errored() + && serve_from_cache.is_on() + && !session.has_pending_downstream_tasks() => { // backpressure: don't queue if pending writes let task = self.h1_response_filter(session, task?, ctx, &mut serve_from_cache, &mut range_body_filter, true).await?; debug!("serve_from_cache task {task:?}"); - match session.write_response_tasks(vec![task]).await { - Ok(b) => response_state.maybe_set_cache_done(b), - Err(e) => if serve_from_cache.is_miss() { - // give up writing to downstream but wait for upstream cache write to finish - downstream_state.to_errored(); - response_state.maybe_set_cache_done(true); - warn!( - "Downstream Error ignored during caching: {}, {}", - e, - self.inner.request_summary(session, ctx) - ); - // This will not be treated as a final error, but we should signal to - // downstream session regardless - session.downstream_session.on_proxy_failure(e); - continue; - } else { - return Err(e); + if session.downstream_session.supports_proxy_task_api() { + session.send_downstream_proxy_task(task).await?; + } else { + match session.write_response_tasks(vec![task]).await { + Ok(b) => response_state.maybe_set_cache_done(b), + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream Error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); + continue; + } else { + return Err(e); + } + } + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } } } - if response_state.cached_done() { - if let Err(e) = session.cache.finish_hit_handler().await { - warn!("Error during finish_hit_handler: {}", e); + } + + // Write queued downstream proxy tasks while also polling for upstream tasks. + // This allows cache writes to continue even when downstream is stalled. + // + // "Gate" branch: ready(()) resolves immediately, so the guard controls + // whether we enter. This is not a busy-loop because every path through + // the inner select either (a) drains all pending tasks via + // write_downstream_proxy_tasks (making the guard false), (b) stores an + // upstream task in next_upstream_task (making the guard false), or + // (c) blocks on real I/O inside the nested select. + _ = std::future::ready(()), if session.has_pending_downstream_tasks() && next_upstream_task.is_none() => { + tokio::select! { + // Try to write downstream proxy tasks (cancel-safe) + write_result = session.write_downstream_proxy_tasks() => { + match write_result { + Ok(end) => { + response_state.maybe_set_cache_done(end); + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } + } + } + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream write error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); + } else { + return Err(e); + } + } + } + + // Also poll for upstream tasks - if we get one, cancel the write and handle it. + // Only poll if there is no buffered task already waiting to be processed. + upstream_task = rx.recv(), if !response_state.upstream_done() && serve_from_cache.is_on() && next_upstream_task.is_none() => { + if let Some(t) = upstream_task { + // Store this upstream task to be processed next iteration + next_upstream_task = Some(t); + continue; + } else { + response_state.maybe_set_upstream_done(true); + } } } } diff --git a/pingora-proxy/src/proxy_h2.rs b/pingora-proxy/src/proxy_h2.rs index acf61f07..97b4fb64 100644 --- a/pingora-proxy/src/proxy_h2.rs +++ b/pingora-proxy/src/proxy_h2.rs @@ -444,6 +444,7 @@ where continue; } + // TODO: If downstream supports proxy task API, should use send_downstream_proxy_task() let response_done = session.write_response_tasks(filtered_tasks).await?; if session.was_upgraded() { // it is very weird if the downstream session decides to upgrade @@ -464,6 +465,7 @@ where &mut range_body_filter, true).await?; debug!("serve_from_cache task {task:?}"); + // TODO: If downstream supports proxy task API, should use send_downstream_proxy_task() match session.write_response_tasks(vec![task]).await { Ok(b) => response_state.maybe_set_cache_done(b), Err(e) => if serve_from_cache.is_miss() { diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs index 9ae4511e..eeafcda9 100644 --- a/pingora-proxy/tests/test_upstream.rs +++ b/pingora-proxy/tests/test_upstream.rs @@ -2913,6 +2913,154 @@ mod test_cache { assert_eq!(res.text().await.unwrap(), "hello world"); } + #[tokio::test] + async fn test_caching_when_downstream_stalls() { + use std::net::ToSocketAddrs; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpStream; + + init(); + let url = "http://127.0.0.1:6148/unique/test_caching_when_downstream_stalls/download/"; + + // Connection 1: read 10KiB then stall, holding the cache lock while + // the proxy populates cache from upstream. + let slow_task = tokio::spawn(async move { + let addr = "127.0.0.1:6148".to_socket_addrs().unwrap().next().unwrap(); + let mut stream = TcpStream::connect(&addr).await.unwrap(); + + let request = concat!( + "GET /unique/test_caching_when_downstream_stalls/download/ HTTP/1.1\r\n", + "Host: 127.0.0.1:6148\r\n", + "x-lock: true\r\n", + "x-set-cache-control: public, max-age=60\r\n", + "\r\n", + ); + stream.write_all(request.as_bytes()).await.unwrap(); + + let mut buf = [0; 10 * 1024]; + let mut b = &mut buf[..]; + while !b.is_empty() { + let n = stream.read(b).await.unwrap(); + b = &mut b[n..] + } + + // Hold the stalled connection open long enough + sleep(Duration::from_secs(10)).await; + }); + + // Give connection 1 time to acquire the cache lock + sleep(Duration::from_secs(1)).await; + + // Connection 2: should get a cache hit once the proxy finishes + // populating cache from upstream (independent of stall). + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(url) + .header("x-lock", "true") + .header("x-set-cache-control", "public, max-age=60") + .timeout(Duration::from_secs(8)) + .send() + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_eq!(headers["x-cache-status"], "hit"); + + // If the cache was populated fast enough (before connection 2 arrived), + // there is no lock contention and x-cache-lock-time-ms is absent. + // If there was contention, the wait should be short. + if let Some(lock_ms) = headers.get("x-cache-lock-time-ms") { + let ms: u64 = lock_ms.to_str().unwrap().parse().unwrap(); + assert!( + ms < 2000, + "lock wait {ms}ms should be well under the 2s timeout" + ); + } + + assert_eq!( + res.text().await.unwrap(), + String::from("A").repeat(4 * 1024 * 1024) + ); + + let elapsed = start.elapsed(); + assert!( + elapsed < Duration::from_secs(5), + "second request took {elapsed:?}, should be fast" + ); + + // Don't wait for the slow connection + slow_task.abort(); + } + + // Same as test_caching_when_downstream_stalls but the proxy connects + // to the origin over H2 (via the x-h2 header). + // + // Ignored until proxy_h2 gets the proxy task API. + #[tokio::test] + #[ignore] + async fn test_caching_h2_upstream_when_downstream_stalls() { + use std::net::ToSocketAddrs; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpStream; + + init(); + let url = "http://127.0.0.1:6148/unique/test_caching_h2_upstream_when_downstream_stalls/download/"; + + let slow_task = tokio::spawn(async move { + let addr = "127.0.0.1:6148".to_socket_addrs().unwrap().next().unwrap(); + let mut stream = TcpStream::connect(&addr).await.unwrap(); + + let request = concat!( + "GET /unique/test_caching_h2_upstream_when_downstream_stalls/download/ HTTP/1.1\r\n", + "Host: 127.0.0.1:6148\r\n", + "x-h2: true\r\n", + "x-lock: true\r\n", + "x-set-cache-control: public, max-age=60\r\n", + "\r\n", + ); + stream.write_all(request.as_bytes()).await.unwrap(); + + let mut buf = [0; 10 * 1024]; + let mut b = &mut buf[..]; + while !b.is_empty() { + let n = stream.read(b).await.unwrap(); + b = &mut b[n..] + } + + sleep(Duration::from_secs(10)).await; + }); + + sleep(Duration::from_secs(1)).await; + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(url) + .header("x-h2", "true") + .header("x-lock", "true") + .header("x-set-cache-control", "public, max-age=60") + .timeout(Duration::from_secs(8)) + .send() + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_eq!(headers["x-cache-status"], "hit"); + assert_eq!( + res.text().await.unwrap(), + String::from("A").repeat(4 * 1024 * 1024) + ); + + let elapsed = start.elapsed(); + assert!( + elapsed < Duration::from_secs(5), + "second request took {elapsed:?}, should be fast (upstream-speed-bound)" + ); + + slow_task.abort(); + } + async fn send_vary_req_with_headers_with_dups( url: &str, vary_field: &str, @@ -3536,4 +3684,145 @@ mod test_cache { assert_eq!(headers["x-cache-status"], "hit"); assert_eq!(res.text().await.unwrap(), "hello world"); } + + // Ignored until H2 downstream gets the proxy task API + // (write_response_tasks blocks on flow control today). + // multi_thread needed for h2 connection driver tasks. + #[tokio::test(flavor = "multi_thread")] + #[ignore] + async fn test_cache_h2_downstream_stalls() { + init(); + + use h2::client; + use http::Request; + use tokio::net::TcpStream; + use tokio::time::{timeout, Duration}; + + // Step 1: Connection 1 - Open h2 connection to h2c cache proxy (port 6154) and STALL + let tcp1 = TcpStream::connect("127.0.0.1:6154").await.unwrap(); + let (mut h2_client1, h2_conn1) = client::handshake(tcp1).await.unwrap(); + + tokio::spawn(async move { + if let Err(e) = h2_conn1.await { + eprintln!("H2 connection 1 error: {:?}", e); + } + }); + + // Request the cached resource on connection 1 + let request1 = Request::builder() + .uri("http://127.0.0.1/unique/test_h2_stall/download/") + .body(()) + .unwrap(); + + let (response1, _) = h2_client1.send_request(request1, true).unwrap(); + let response1 = response1.await.unwrap(); + assert_eq!(response1.status(), 200); + assert_eq!(response1.headers()["x-cache-status"], "miss"); + + let mut body1 = response1.into_body(); + + // Read first chunk but don't release flow control to stall connection 1 + let first_chunk = body1.data().await.unwrap().unwrap(); + assert!(!first_chunk.is_empty()); + + // Connection 2 - While conn 1 is stalled, try to get the same cached resource + let tcp2 = TcpStream::connect("127.0.0.1:6154").await.unwrap(); + let (mut h2_client2, h2_conn2) = client::handshake(tcp2).await.unwrap(); + + tokio::spawn(async move { + if let Err(e) = h2_conn2.await { + eprintln!("H2 connection 2 error: {:?}", e); + } + }); + + let request2 = Request::builder() + .uri("http://127.0.0.1/unique/test_h2_stall/download/") + .body(()) + .unwrap(); + + let (response2, _) = h2_client2.send_request(request2, true).unwrap(); + + // Try to read, proxy should not be blocked + let response2 = match timeout(Duration::from_secs(5), response2).await { + Ok(Ok(resp)) => resp, + Ok(Err(e)) => panic!("Connection 2 failed: {:?}", e), + Err(_) => panic!("Connection 2 timed out - proxy blocked without proxy task API!"), + }; + + assert_eq!(response2.status(), 200); + assert_eq!(response2.headers()["x-cache-status"], "hit"); + + // Read full response from connection 2 + let mut body2 = response2.into_body(); + let mut received2 = Vec::new(); + while let Some(Ok(chunk)) = timeout(Duration::from_secs(5), body2.data()) + .await + .expect("should not time out waiting for data") + { + let len = chunk.len(); + received2.extend_from_slice(&chunk); + body2.flow_control().release_capacity(len).unwrap(); + } + + assert_eq!( + received2.len(), + 4 * 1024 * 1024, + "Connection 2 should receive full cached response" + ); + + // Clean up: unstall connection 1 + body1 + .flow_control() + .release_capacity(first_chunk.len()) + .unwrap(); + } + + // Test cache population from H2 upstream origin with H1 downstream. + #[tokio::test] + async fn test_cache_upstream_h2_downstream_h1() { + init(); + + let test_url = "http://127.0.0.1:6148/unique/test_h2_upstream/download/"; + + // Step 1: Populate cache from H2 origin (cache miss) + let client = reqwest::Client::new(); + let res = client + .get(test_url) + .header("x-h2", "true") + .header("x-lock", "true") + .header("x-set-cache-control", "public, max-age=60") + .send() + .await + .unwrap(); + + assert_eq!(res.status(), 200); + assert_eq!(res.headers()["x-cache-status"], "miss"); + assert_eq!(res.headers()["origin-http2"], "h2c"); + + let body = res.bytes().await.unwrap(); + assert_eq!( + body.len(), + 4 * 1024 * 1024, + "Should receive full 4MB response" + ); + + // Step 2: Request again and verify cache hit + let res = client + .get(test_url) + .header("x-h2", "true") + .header("x-set-cache-control", "public, max-age=60") + .send() + .await + .unwrap(); + + assert_eq!(res.status(), 200); + assert_eq!(res.headers()["x-cache-status"], "hit"); + + let body = res.bytes().await.unwrap(); + assert_eq!( + body.len(), + 4 * 1024 * 1024, + "Should receive full 4MB from cache" + ); + } } diff --git a/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf b/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf index f19c974c..969695eb 100644 --- a/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf +++ b/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf @@ -311,7 +311,6 @@ http { location /download/ { content_by_lua_block { - ngx.req.read_body() local body = string.rep("A", 4194304) ngx.header["Content-Length"] = #body ngx.print(body) diff --git a/pingora-proxy/tests/utils/server_utils.rs b/pingora-proxy/tests/utils/server_utils.rs index 9361182e..37423707 100644 --- a/pingora-proxy/tests/utils/server_utils.rs +++ b/pingora-proxy/tests/utils/server_utils.rs @@ -658,6 +658,12 @@ impl ProxyHttp for ExampleProxyCache { upstream_response.remove_header(&CONTENT_LENGTH); upstream_response.remove_header(&TRANSFER_ENCODING); } + // Allow tests to inject Cache-Control into the upstream response + if let Some(cc) = session.req_header().headers.get("x-set-cache-control") { + upstream_response + .insert_header(http::header::CACHE_CONTROL, cc) + .unwrap(); + } Ok(()) } @@ -823,6 +829,15 @@ fn test_main() { pingora_proxy::http_proxy_service(&my_server.configuration, ExampleProxyCache {}); proxy_service_cache.add_tcp("0.0.0.0:6148"); + // H2C-enabled cache proxy on port 6154 + let mut proxy_service_cache_h2c = + pingora_proxy::http_proxy_service(&my_server.configuration, ExampleProxyCache {}); + let cache_h2c_logic = proxy_service_cache_h2c.app_logic_mut().unwrap(); + let mut cache_h2c_options = HttpServerOptions::default(); + cache_h2c_options.h2c = true; + cache_h2c_logic.server_options = Some(cache_h2c_options); + proxy_service_cache_h2c.add_tcp("0.0.0.0:6154"); + #[cfg(feature = "any_tls")] { let cert_path = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR")); @@ -839,6 +854,7 @@ fn test_main() { Box::new(proxy_service_http), Box::new(proxy_service_http_connect), Box::new(proxy_service_cache), + Box::new(proxy_service_cache_h2c), ]; if let Some(proxy_service_https) = proxy_service_https_opt { From ce16618b6c84625125c93769f5351e98427480ea Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Thu, 19 Mar 2026 21:35:38 -0700 Subject: [PATCH 17/52] Add per-session toggle for the proxy task API --- .bleep | 2 +- pingora-core/src/protocols/http/server.rs | 16 ++++++++++++++-- pingora-core/src/protocols/http/v1/server.rs | 10 ++++++++++ pingora-proxy/src/lib.rs | 9 +++++++++ pingora-proxy/tests/utils/server_utils.rs | 15 +++++++++++++-- 5 files changed, 47 insertions(+), 5 deletions(-) diff --git a/.bleep b/.bleep index 6cca1e78..45ba67d2 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -033d34cfe2e46f59be14e956033f0b55af3daa45 \ No newline at end of file +5a6f64463ffe8f272e2ad8e9474ea2a25e863de3 \ No newline at end of file diff --git a/pingora-core/src/protocols/http/server.rs b/pingora-core/src/protocols/http/server.rs index 65c51723..438f3cb0 100644 --- a/pingora-core/src/protocols/http/server.rs +++ b/pingora-core/src/protocols/http/server.rs @@ -813,9 +813,21 @@ impl Session { } /// Check if this session supports the cancel-safe proxy task API. + /// + /// For HTTP/1.x, this can be toggled per-session via + /// [`set_proxy_tasks_enabled`](Self::set_proxy_tasks_enabled). pub fn supports_proxy_task_api(&self) -> bool { - // only H1 for now - matches!(self, Self::H1(_)) + match self { + Self::H1(s) => s.proxy_tasks_enabled(), + _ => false, + } + } + + /// Enable or disable the cancel-safe proxy task API for this session. + pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { + if let Self::H1(s) = self { + s.set_proxy_tasks_enabled(enabled); + } } /// Queue a downstream proxy task for cancel-safe writing. diff --git a/pingora-core/src/protocols/http/v1/server.rs b/pingora-core/src/protocols/http/v1/server.rs index 0cfdb47d..d80fe3ef 100644 --- a/pingora-core/src/protocols/http/v1/server.rs +++ b/pingora-core/src/protocols/http/v1/server.rs @@ -876,6 +876,16 @@ impl HttpSession { } } + /// Whether the cancel-safe proxy task API is enabled for this session. + pub fn proxy_tasks_enabled(&self) -> bool { + self.proxy_tasks_enabled + } + + /// Enable or disable the cancel-safe proxy task API for this session. + pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { + self.proxy_tasks_enabled = enabled; + } + async fn do_write_body_buf(&mut self) -> Result> { // Don't flush empty chunks, they are considered end of body for chunks if self.body_write_buf.is_empty() { diff --git a/pingora-proxy/src/lib.rs b/pingora-proxy/src/lib.rs index 3faad4e4..3a50f704 100644 --- a/pingora-proxy/src/lib.rs +++ b/pingora-proxy/src/lib.rs @@ -663,6 +663,15 @@ impl Session { Ok(()) } + /// Enable or disable the cancel-safe proxy task API for this session. + /// + /// When disabled, the proxy falls back to the blocking `write_response_tasks` + /// path. This can be called from request filters to opt out on a per-request + /// basis. + pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { + self.downstream_session.set_proxy_tasks_enabled(enabled); + } + /// Check if there are pending downstream tasks queued for writing. /// Used for backpressure - don't queue more cache tasks if we have pending writes. /// Returns false for sessions that don't support the proxy task API. diff --git a/pingora-proxy/tests/utils/server_utils.rs b/pingora-proxy/tests/utils/server_utils.rs index 37423707..0dccb6dd 100644 --- a/pingora-proxy/tests/utils/server_utils.rs +++ b/pingora-proxy/tests/utils/server_utils.rs @@ -253,8 +253,19 @@ impl ProxyHttp for ExampleProxyHttp { session: &mut Session, _ctx: &mut Self::CTX, ) -> Result<()> { - let req = session.req_header(); - let downstream_compression = req.headers.get("x-downstream-compression").is_some(); + let proxy_tasks_enabled = session + .req_header() + .headers + .get("x-proxy-tasks-enabled") + .is_some(); + if proxy_tasks_enabled { + session.downstream_session.set_proxy_tasks_enabled(true); + } + let downstream_compression = session + .req_header() + .headers + .get("x-downstream-compression") + .is_some(); if downstream_compression { session .downstream_modules_ctx From 969eb67d1bba3a012cfd5d4a0f12c06070ebaade Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Thu, 19 Mar 2026 21:38:23 -0700 Subject: [PATCH 18/52] Use proxy task API in proxy_h2 and proxy_custom for cache-served downstream writes --- .bleep | 2 +- pingora-core/src/protocols/http/v1/server.rs | 2 +- pingora-proxy/src/proxy_custom.rs | 270 +++++++++++++----- pingora-proxy/src/proxy_h2.rs | 277 +++++++++++++------ pingora-proxy/tests/test_upstream.rs | 2 - 5 files changed, 400 insertions(+), 153 deletions(-) diff --git a/.bleep b/.bleep index 45ba67d2..28024185 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -5a6f64463ffe8f272e2ad8e9474ea2a25e863de3 \ No newline at end of file +855ad50aae6d7bc16238e8dbd1a21fc9dcc5cab9 \ No newline at end of file diff --git a/pingora-core/src/protocols/http/v1/server.rs b/pingora-core/src/protocols/http/v1/server.rs index d80fe3ef..f4465265 100644 --- a/pingora-core/src/protocols/http/v1/server.rs +++ b/pingora-core/src/protocols/http/v1/server.rs @@ -1402,7 +1402,7 @@ impl HttpSession { /// Check if there are pending proxy tasks queued for writing. pub fn has_pending_proxy_tasks(&self) -> bool { - !self.proxy_task_state.tasks.is_empty() + self.proxy_task_state.current_writer.is_some() || !self.proxy_task_state.tasks.is_empty() } /// Write all queued proxy tasks (response `HttpTask`s from `send_proxy_task`) diff --git a/pingora-proxy/src/proxy_custom.rs b/pingora-proxy/src/proxy_custom.rs index b571b3ce..b7ee1d50 100644 --- a/pingora-proxy/src/proxy_custom.rs +++ b/pingora-proxy/src/proxy_custom.rs @@ -257,7 +257,88 @@ where } } - // returns whether server (downstream) session can be reused + #[allow(clippy::too_many_arguments)] + async fn process_upstream_tasks_custom( + &self, + session: &mut Session, + ctx: &mut SV::CTX, + initial_task: HttpTask, + rx: &mut mpsc::Receiver, + serve_from_cache: &mut ServeFromCache, + range_body_filter: &mut proxy_cache::range_filter::RangeBodyFilter, + response_state: &mut ResponseStateMachine, + ) -> Result> + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + if serve_from_cache.should_discard_upstream() { + // just drain, do we need to do anything else? + return Ok(None); + } + + // Batch: pull as many tasks as we can from rx + let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + tasks.push(initial_task); + while let Ok(task) = rx.try_recv() { + tasks.push(task); + } + + /* run filters before sending to downstream */ + let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + for mut t in tasks { + if self.revalidate_or_stale(session, &mut t, ctx).await { + serve_from_cache.enable(); + response_state.enable_cached_response(); + // skip downstream filtering entirely as the 304 will not be sent + break; + } + #[cfg(feature = "adjust_upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } + session.upstream_compression.response_filter(&mut t); + // check error and abort + // otherwise the error is surfaced via write_response_tasks() + if !serve_from_cache.should_send_to_downstream() { + if let HttpTask::Failed(e) = t { + return Err(e); + } + } + filtered_tasks.push( + self.custom_response_filter( + session, + t, + ctx, + serve_from_cache, + range_body_filter, + false, + ) + .await?, + ); + if serve_from_cache.is_miss_header() { + response_state.enable_cached_response(); + } + } + + if !serve_from_cache.should_send_to_downstream() { + // TODO: need to derive response_done from filtered_tasks in case downstream failed already + return Ok(None); + } + + let response_done = session.write_response_tasks(filtered_tasks).await?; + + Ok(Some(response_done)) + } + + // TODO: pre-existing inconsistency with proxy_h1/proxy_h2 to address in a follow-up: + // upstream task rx.recv() branch is missing + // downstream_state.maybe_finished(session.is_body_done()) after processing. proxy_h1 has + // this because upgrade responses can force the body done — since custom upstreams can + // serve H1 downstreams that support upgrades, the same may be needed here. + // Returns whether server (downstream) session can be reused #[allow(clippy::too_many_arguments)] async fn custom_bidirection_down_to_up( &self, @@ -303,6 +384,8 @@ where let mut serve_from_cache = ServeFromCache::new(); let mut range_body_filter = proxy_cache::range_filter::RangeBodyFilter::new(); + let mut next_upstream_task: Option = None; + let mut upstream_custom = true; let mut downstream_custom = true; @@ -361,99 +444,142 @@ where }; }, - task = rx.recv(), if !response_state.upstream_done() => { - debug!("upstream event"); - + // Handle buffered upstream task from previous iteration + task = async { next_upstream_task.take() }, if next_upstream_task.is_some() => { + debug!("buffered upstream event: {:?}", task); if let Some(t) = task { - debug!("upstream event custom: {:?}", t); - if serve_from_cache.should_discard_upstream() { - // just drain, do we need to do anything else? - continue; - } - // pull as many tasks as we can - let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - tasks.push(t); - while let Ok(task) = rx.try_recv() { - tasks.push(task); - } - - /* run filters before sending to downstream */ - let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - for mut t in tasks { - if self.revalidate_or_stale(session, &mut t, ctx).await { - serve_from_cache.enable(); - response_state.enable_cached_response(); - // skip downstream filtering entirely as the 304 will not be sent - break; - } - #[cfg(feature = "adjust_upstream_modules")] - if let HttpTask::Header(header, end_of_stream) = &t { - self.inner - .adjust_upstream_modules(session, header, *end_of_stream, ctx) - .await?; - } - session.upstream_compression.response_filter(&mut t); - // check error and abort - // otherwise the error is surfaced via write_response_tasks() - if !serve_from_cache.should_send_to_downstream() { - if let HttpTask::Failed(e) = t { - return Err(e); - } - } - filtered_tasks.push( - self.custom_response_filter(session, t, ctx, - &mut serve_from_cache, - &mut range_body_filter, false).await?); - if serve_from_cache.is_miss_header() { - response_state.enable_cached_response(); - } - } - - if !serve_from_cache.should_send_to_downstream() { - // TODO: need to derive response_done from filtered_tasks in case downstream failed already + let Some(response_done) = self.process_upstream_tasks_custom( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache continue; - } + }; + response_state.maybe_set_upstream_done(response_done); + } else { + debug!("empty upstream event"); + response_state.maybe_set_upstream_done(true); + } + }, + task = rx.recv(), if !response_state.upstream_done() && next_upstream_task.is_none() => { + debug!("upstream event: {:?}", task); + if let Some(t) = task { let upgraded = session.was_upgraded(); - let response_done = session.write_response_tasks(filtered_tasks).await?; + let Some(response_done) = self.process_upstream_tasks_custom( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache + continue; + }; if !upgraded && session.was_upgraded() && downstream_state.can_poll() { // just upgraded, the downstream state should be reset to continue to // poll body trace!("reset downstream state on upgrade"); downstream_state.reset(); } - response_state.maybe_set_upstream_done(response_done); } else { debug!("empty upstream event"); response_state.maybe_set_upstream_done(true); } - } + }, task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), - if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { + if !response_state.cached_done() + && !downstream_state.is_errored() + && serve_from_cache.is_on() + && !session.has_pending_downstream_tasks() => { // backpressure: don't queue if pending writes + let task = self.custom_response_filter(session, task?, ctx, &mut serve_from_cache, &mut range_body_filter, true).await?; - match session.write_response_tasks(vec![task]).await { - Ok(b) => response_state.maybe_set_cache_done(b), - Err(e) => if serve_from_cache.is_miss() { - // give up writing to downstream but wait for upstream cache write to finish - downstream_state.to_errored(); - response_state.maybe_set_cache_done(true); - warn!( - "Downstream Error ignored during caching: {}, {}", - e, - self.inner.request_summary(session, ctx) - ); - continue; - } else { - return Err(e); + + if session.downstream_session.supports_proxy_task_api() { + session.send_downstream_proxy_task(task).await?; + } else { + match session.write_response_tasks(vec![task]).await { + Ok(b) => response_state.maybe_set_cache_done(b), + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream Error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + session.downstream_session.on_proxy_failure(e); + continue; + } else { + return Err(e); + } + } + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } } } - if response_state.cached_done() { - if let Err(e) = session.cache.finish_hit_handler().await { - warn!("Error during finish_hit_handler: {}", e); + } + + // Write queued downstream proxy tasks while also polling for upstream tasks. + // This allows cache writes to continue even when downstream is stalled. + // + // "Gate" branch: ready(()) resolves immediately, so the guard controls + // whether we enter. This is not a busy-loop because every path through + // the inner select either (a) drains all pending tasks via + // write_downstream_proxy_tasks (making the guard false), (b) stores an + // upstream task in next_upstream_task (making the guard false), or + // (c) blocks on real I/O inside the nested select. + _ = std::future::ready(()), if session.has_pending_downstream_tasks() && next_upstream_task.is_none() => { + tokio::select! { + // Try to write downstream proxy tasks (cancel-safe) + write_result = session.write_downstream_proxy_tasks() => { + match write_result { + Ok(end) => { + response_state.maybe_set_cache_done(end); + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } + } + } + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream write error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + session.downstream_session.on_proxy_failure(e); + } else { + return Err(e); + } + } + } + + // Also poll for upstream tasks - if we get one, cancel the write and handle it. + upstream_task = rx.recv(), if !response_state.upstream_done() && serve_from_cache.is_on() && next_upstream_task.is_none() => { + if let Some(t) = upstream_task { + next_upstream_task = Some(t); + continue; + } else { + response_state.maybe_set_upstream_done(true); + } } } } diff --git a/pingora-proxy/src/proxy_h2.rs b/pingora-proxy/src/proxy_h2.rs index 97b4fb64..afe58a0b 100644 --- a/pingora-proxy/src/proxy_h2.rs +++ b/pingora-proxy/src/proxy_h2.rs @@ -265,6 +265,87 @@ where (server_session_reuse, error) } + #[allow(clippy::too_many_arguments)] + async fn process_upstream_tasks_h2( + &self, + session: &mut Session, + ctx: &mut SV::CTX, + initial_task: HttpTask, + rx: &mut mpsc::Receiver, + serve_from_cache: &mut ServeFromCache, + range_body_filter: &mut proxy_cache::range_filter::RangeBodyFilter, + response_state: &mut ResponseStateMachine, + ) -> Result> + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + if serve_from_cache.should_discard_upstream() { + // just drain, do we need to do anything else? + return Ok(None); + } + + // Batch: pull as many tasks as we can from rx + let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + tasks.push(initial_task); + // tokio::task::unconstrained because now_or_never may yield None when the future is ready + while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { + if let Some(t) = maybe_task { + tasks.push(t); + } else { + break; // upstream closed + } + } + + /* run filters before sending to downstream */ + let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + for mut t in tasks { + if self.revalidate_or_stale(session, &mut t, ctx).await { + serve_from_cache.enable(); + response_state.enable_cached_response(); + // skip downstream filtering entirely as the 304 will not be sent + break; + } + #[cfg(feature = "adjust_upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } + session.upstream_compression.response_filter(&mut t); + // check error and abort + // otherwise the error is surfaced via write_response_tasks() + if !serve_from_cache.should_send_to_downstream() { + if let HttpTask::Failed(e) = t { + return Err(e); + } + } + filtered_tasks.push( + self.h2_response_filter( + session, + t, + ctx, + serve_from_cache, + range_body_filter, + false, + ) + .await?, + ); + if serve_from_cache.is_miss_header() { + response_state.enable_cached_response(); + } + } + + if !serve_from_cache.should_send_to_downstream() { + // TODO: need to derive response_done from filtered_tasks in case downstream failed already + return Ok(None); + } + + let response_done = session.write_response_tasks(filtered_tasks).await?; + + Ok(Some(response_done)) + } + // returns whether server (downstream) session can be reused async fn bidirection_down_to_up( &self, @@ -322,6 +403,8 @@ where let mut serve_from_cache = ServeFromCache::new(); let mut range_body_filter = proxy_cache::range_filter::RangeBodyFilter::new(); + let mut next_upstream_task: Option = None; + /* duplex mode * see the Same function for h1 for more comments */ @@ -388,64 +471,47 @@ where }; }, - task = rx.recv(), if !response_state.upstream_done() => { + // Handle buffered upstream task from previous iteration + task = async { next_upstream_task.take() }, if next_upstream_task.is_some() => { + debug!("buffered upstream event: {:?}", task); if let Some(t) = task { - debug!("upstream event: {:?}", t); - if serve_from_cache.should_discard_upstream() { - // just drain, do we need to do anything else? - continue; - } - // pull as many tasks as we can - let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - tasks.push(t); - // tokio::task::unconstrained because now_or_never may yield None when the future is ready - while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { - if let Some(t) = maybe_task { - tasks.push(t); - } else { - break - } - } - - /* run filters before sending to downstream */ - let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - for mut t in tasks { - if self.revalidate_or_stale(session, &mut t, ctx).await { - serve_from_cache.enable(); - response_state.enable_cached_response(); - // skip downstream filtering entirely as the 304 will not be sent - break; - } - #[cfg(feature = "adjust_upstream_modules")] - if let HttpTask::Header(header, end_of_stream) = &t { - self.inner - .adjust_upstream_modules(session, header, *end_of_stream, ctx) - .await?; - } - session.upstream_compression.response_filter(&mut t); - // check error and abort - // otherwise the error is surfaced via write_response_tasks() - if !serve_from_cache.should_send_to_downstream() { - if let HttpTask::Failed(e) = t { - return Err(e); - } - } - filtered_tasks.push( - self.h2_response_filter(session, t, ctx, - &mut serve_from_cache, - &mut range_body_filter, false).await?); - if serve_from_cache.is_miss_header() { - response_state.enable_cached_response(); - } - } - - if !serve_from_cache.should_send_to_downstream() { - // TODO: need to derive response_done from filtered_tasks in case downstream failed already + let Some(response_done) = self.process_upstream_tasks_h2( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache continue; + }; + if session.was_upgraded() { + return Error::e_explain(H2Error, "upgraded while proxying to h2 session"); } + response_state.maybe_set_upstream_done(response_done); + } else { + debug!("empty upstream event"); + response_state.maybe_set_upstream_done(true); + } + }, - // TODO: If downstream supports proxy task API, should use send_downstream_proxy_task() - let response_done = session.write_response_tasks(filtered_tasks).await?; + task = rx.recv(), if !response_state.upstream_done() && next_upstream_task.is_none() => { + debug!("upstream event: {:?}", task); + if let Some(t) = task { + let Some(response_done) = self.process_upstream_tasks_h2( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache + continue; + }; if session.was_upgraded() { // it is very weird if the downstream session decides to upgrade // since the client h2 session cannot, return an error on this case @@ -456,38 +522,95 @@ where debug!("empty upstream event"); response_state.maybe_set_upstream_done(true); } - } + }, task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), - if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { + if !response_state.cached_done() + && !downstream_state.is_errored() + && serve_from_cache.is_on() + && !session.has_pending_downstream_tasks() => { // backpressure: don't queue if pending writes + let task = self.h2_response_filter(session, task?, ctx, &mut serve_from_cache, &mut range_body_filter, true).await?; debug!("serve_from_cache task {task:?}"); - // TODO: If downstream supports proxy task API, should use send_downstream_proxy_task() - match session.write_response_tasks(vec![task]).await { - Ok(b) => response_state.maybe_set_cache_done(b), - Err(e) => if serve_from_cache.is_miss() { - // give up writing to downstream but wait for upstream cache write to finish - downstream_state.to_errored(); - response_state.maybe_set_cache_done(true); - warn!( - "Downstream Error ignored during caching: {}, {}", - e, - self.inner.request_summary(session, ctx) - ); - // This will not be treated as a final error, but we should signal to - // downstream session regardless - session.downstream_session.on_proxy_failure(e); - continue; - } else { - return Err(e); + if session.downstream_session.supports_proxy_task_api() { + session.send_downstream_proxy_task(task).await?; + } else { + match session.write_response_tasks(vec![task]).await { + Ok(b) => response_state.maybe_set_cache_done(b), + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream Error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); + continue; + } else { + return Err(e); + } + } + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } } } - if response_state.cached_done() { - if let Err(e) = session.cache.finish_hit_handler().await { - warn!("Error during finish_hit_handler: {}", e); + } + + // Write queued downstream proxy tasks while also polling for upstream tasks. + // This allows cache writes to continue even when downstream is stalled. + // + // "Gate" branch: ready(()) resolves immediately, so the guard controls + // whether we enter. This is not a busy-loop because every path through + // the inner select either (a) drains all pending tasks via + // write_downstream_proxy_tasks (making the guard false), (b) stores an + // upstream task in next_upstream_task (making the guard false), or + // (c) blocks on real I/O inside the nested select. + _ = std::future::ready(()), if session.has_pending_downstream_tasks() && next_upstream_task.is_none() => { + tokio::select! { + // Try to write downstream proxy tasks (cancel-safe) + write_result = session.write_downstream_proxy_tasks() => { + match write_result { + Ok(end) => { + response_state.maybe_set_cache_done(end); + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } + } + } + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream write error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + session.downstream_session.on_proxy_failure(e); + } else { + return Err(e); + } + } + } + + // Also poll for upstream tasks - if we get one, cancel the write and handle it. + upstream_task = rx.recv(), if !response_state.upstream_done() && serve_from_cache.is_on() && next_upstream_task.is_none() => { + if let Some(t) = upstream_task { + next_upstream_task = Some(t); + continue; + } else { + response_state.maybe_set_upstream_done(true); + } } } } diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs index eeafcda9..cdba09b7 100644 --- a/pingora-proxy/tests/test_upstream.rs +++ b/pingora-proxy/tests/test_upstream.rs @@ -2996,9 +2996,7 @@ mod test_cache { // Same as test_caching_when_downstream_stalls but the proxy connects // to the origin over H2 (via the x-h2 header). // - // Ignored until proxy_h2 gets the proxy task API. #[tokio::test] - #[ignore] async fn test_caching_h2_upstream_when_downstream_stalls() { use std::net::ToSocketAddrs; use tokio::io::{AsyncReadExt, AsyncWriteExt}; From e7de90a7a62eca781ac1be8916ae859dcd9632a5 Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Fri, 27 Mar 2026 16:49:24 -0700 Subject: [PATCH 19/52] Fix body bytes count on v1 session This was previously counting the response header bytes as well, which is incorrect. --- .bleep | 2 +- pingora-core/src/protocols/http/v1/server.rs | 25 +++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/.bleep b/.bleep index 28024185..bea6846e 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -855ad50aae6d7bc16238e8dbd1a21fc9dcc5cab9 \ No newline at end of file +7d6e87142a10ef59fd622cb2acc48bf331185b4d \ No newline at end of file diff --git a/pingora-core/src/protocols/http/v1/server.rs b/pingora-core/src/protocols/http/v1/server.rs index f4465265..9144c6e5 100644 --- a/pingora-core/src/protocols/http/v1/server.rs +++ b/pingora-core/src/protocols/http/v1/server.rs @@ -566,7 +566,6 @@ impl HttpSession { .or_err(WriteError, "flushing response header")?; } self.response_written = Some(header); - self.body_bytes_sent += write_buf.len(); Ok(()) } Err(e) => Error::e_because(WriteError, "writing response header", e), @@ -2599,6 +2598,30 @@ mod tests_stream { assert_eq!(wire_body.len(), n); } + #[tokio::test] + async fn body_bytes_sent_excludes_response_header() { + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; + let wire_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let wire_body = b"hello"; + let mock_io = Builder::new() + .read(read_wire) + .write(wire_header) + .write(wire_body) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); + new_response.append_header("Content-Length", "5").unwrap(); + http_stream.update_resp_headers = false; + http_stream + .write_response_header(Box::new(new_response)) + .await + .unwrap(); + assert_eq!(http_stream.body_bytes_sent(), 0); + http_stream.write_body(wire_body).await.unwrap(); + assert_eq!(http_stream.body_bytes_sent(), wire_body.len()); + } + #[tokio::test] async fn write_body_http10() { let read_wire = b"GET / HTTP/1.1\r\n\r\n"; From 9267745ba11046da5bb59b83dd17746ac5f91aa5 Mon Sep 17 00:00:00 2001 From: Fei Deng Date: Thu, 2 Apr 2026 13:00:21 -0400 Subject: [PATCH 20/52] add peek_lru to LRU eviction manager MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add LruUnit::peek_lru(), Lru::peek_lru(shard), and Manager::peek_lru(shard) to peek at the least-recently-used item in a shard without evicting it. Returns None for empty shards or out-of-bounds shard indices. This enables callers to report the eviction frontier — the age of the item that would be evicted next — for cache observability metrics. --- .bleep | 2 +- pingora-cache/src/eviction/lru.rs | 47 ++++++++++++++++++ pingora-lru/src/lib.rs | 80 +++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 1 deletion(-) diff --git a/.bleep b/.bleep index bea6846e..e910494e 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -7d6e87142a10ef59fd622cb2acc48bf331185b4d \ No newline at end of file +d2681a9a12ffd53d2a97e2f528e28e443dbb8318 \ No newline at end of file diff --git a/pingora-cache/src/eviction/lru.rs b/pingora-cache/src/eviction/lru.rs index d241ee69..96285700 100644 --- a/pingora-cache/src/eviction/lru.rs +++ b/pingora-cache/src/eviction/lru.rs @@ -85,6 +85,15 @@ impl Manager { (u64key(key) % N as u64) as usize } + /// Peek at the least-recently-used key in the given shard without evicting it. + /// + /// Returns the cache key at the LRU tail of the shard, or `None` if empty. + /// Useful for reporting the eviction frontier (the age of the next item + /// that would be evicted). + pub fn peek_lru(&self, shard: usize) -> Option { + self.0.peek_lru(shard).map(|(key, _weight)| key) + } + /// Serialize the given shard pub fn serialize_shard(&self, shard: usize) -> Result> { use rmp_serde::encode::Serializer; @@ -614,4 +623,42 @@ mod test { // Cleanup test directory std::fs::remove_dir_all(dir_path).unwrap(); } + + #[test] + fn test_peek_lru() { + let lru = Manager::<1>::with_capacity(20, 20); + let until = SystemTime::now(); + + // empty shard returns None + assert!(lru.peek_lru(0).is_none()); + + let key1 = CacheKey::new("", "a", "1").to_compact(); + lru.admit(key1.clone(), 1, until); + // single item: it's both the head and the tail + assert_eq!(lru.peek_lru(0).unwrap(), key1); + + // admit more keys to push key1 to the tail + let key2 = CacheKey::new("", "b", "1").to_compact(); + lru.admit(key2.clone(), 1, until); + for i in 0..5 { + lru.admit( + CacheKey::new("", format!("f{i}"), "1").to_compact(), + 1, + until, + ); + } + // key1 is the LRU tail (admitted first) + assert_eq!(lru.peek_lru(0).unwrap(), key1); + + // promote key1 — now key2 becomes the tail + lru.access(&key1, 1, until); + assert_eq!(lru.peek_lru(0).unwrap(), key2); + + // peek_lru should not remove the item + assert_eq!(lru.peek_lru(0).unwrap(), key2); + assert!(lru.peek(&key2)); + + // out-of-bounds shard returns None + assert!(lru.peek_lru(999).is_none()); + } } diff --git a/pingora-lru/src/lib.rs b/pingora-lru/src/lib.rs index 23728c4f..af0b2d91 100644 --- a/pingora-lru/src/lib.rs +++ b/pingora-lru/src/lib.rs @@ -226,6 +226,21 @@ impl Lru { self.units[get_shard(key, N)].read().peek_weight(key) } + /// Peek at the least-recently-used item in the given shard without removing it. + /// + /// Returns a clone of the data and the weight, or `None` if the shard is empty + /// or `shard >= N`. + pub fn peek_lru(&self, shard: usize) -> Option<(T, usize)> + where + T: Clone, + { + self.units + .get(shard)? + .read() + .peek_lru() + .map(|(data, weight)| (data.clone(), weight)) + } + /// Return the current total weight. pub fn weight(&self) -> usize { self.weight.load(Ordering::Relaxed) @@ -374,6 +389,19 @@ impl LruUnit { (node.data, node.weight) }) } + + /// Peek at the least-recently-used item without removing it. + /// + /// Returns a reference to the data and weight of the tail item, or `None` + /// if empty. + pub fn peek_lru(&self) -> Option<(&T, usize)> { + self.order + .tail() + .and_then(|idx| self.order.peek(idx)) + .and_then(|key| self.lookup_table.get(&key)) + .map(|node| (&node.data, node.weight)) + } + // TODO: scan the tail up to K elements to decide which ones to evict pub fn remove(&mut self, key: u64) -> Option<(T, usize)> { @@ -696,6 +724,29 @@ mod test_lru { assert_eq!(evicted.len(), 2); assert_eq!(lru.evicted_len(), 2); } + + #[test] + fn test_peek_lru() { + let lru = Lru::::with_capacity(10, 10); + + // empty shard + assert!(lru.peek_lru(0).is_none()); + + lru.admit(1, 10, 1); + assert_eq!(lru.peek_lru(0).unwrap(), (10, 1)); + + lru.admit(2, 20, 2); + // key 1 is LRU tail + assert_eq!(lru.peek_lru(0).unwrap(), (10, 1)); + + // promote key 1 + lru.promote(1); + // key 2 is now LRU tail + assert_eq!(lru.peek_lru(0).unwrap(), (20, 2)); + + // out-of-bounds returns None + assert!(lru.peek_lru(999).is_none()); + } } #[cfg(test)] @@ -865,4 +916,33 @@ mod test_lru_unit { assert_eq!(lru.used_weight(), 1 + 3 + 4 + 5); assert_lru(&lru, &[2, 3, 4, 5]); } + + #[test] + fn test_peek_lru() { + let mut lru = LruUnit::with_capacity(10); + + // empty returns None + assert!(lru.peek_lru().is_none()); + + // single item is both head and tail + lru.admit(1, 10, 1); + let (data, weight) = lru.peek_lru().unwrap(); + assert_eq!(*data, 10); + assert_eq!(weight, 1); + + // second admission pushes first to tail + lru.admit(2, 20, 2); + let (data, _) = lru.peek_lru().unwrap(); + assert_eq!(*data, 10); // key 1 is LRU tail + + // promote key 1 — now key 2 is tail + lru.access(1); + let (data, _) = lru.peek_lru().unwrap(); + assert_eq!(*data, 20); // key 2 is now LRU tail + + // peek doesn't remove + assert!(lru.peek_lru().is_some()); + assert!(lru.peek(1).is_some()); + assert!(lru.peek(2).is_some()); + } } From ea9d9ec81a166c336c29f072857e8d49da84a353 Mon Sep 17 00:00:00 2001 From: Davis To Date: Fri, 20 Mar 2026 13:55:56 -0700 Subject: [PATCH 21/52] Expose Unexpected Data Counter from Connection Pool --- .bleep | 2 +- pingora-core/src/connectors/http/mod.rs | 13 ++++++ pingora-core/src/connectors/http/v1.rs | 13 ++++++ pingora-core/src/connectors/mod.rs | 62 +++++++++++++++++++++++-- pingora-proxy/src/lib.rs | 13 +++++- 5 files changed, 97 insertions(+), 6 deletions(-) diff --git a/.bleep b/.bleep index e910494e..6436301c 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -d2681a9a12ffd53d2a97e2f528e28e443dbb8318 \ No newline at end of file +ddc5c39c76ea10f1bc83dbe58889e937031873e4 \ No newline at end of file diff --git a/pingora-core/src/connectors/http/mod.rs b/pingora-core/src/connectors/http/mod.rs index 2545cf7c..5a671ef2 100644 --- a/pingora-core/src/connectors/http/mod.rs +++ b/pingora-core/src/connectors/http/mod.rs @@ -21,6 +21,8 @@ use crate::protocols::http::client::HttpSession; use crate::protocols::http::v1::client::HttpSession as Http1Session; use crate::upstreams::peer::Peer; use pingora_error::Result; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; use std::time::Duration; pub mod custom; @@ -151,6 +153,17 @@ where pub fn prefer_h1(&self, peer: &impl Peer) { self.h2.prefer_h1(peer); } + + /// Return the number of times a pooled connection was found to contain + /// unexpected data from the server. + pub fn unexpected_data_connection_count(&self) -> u64 { + self.h1.unexpected_data_connection_count() + } + + /// Return a shared reference to the unexpected data connection counter for periodic metric reporting. + pub fn unexpected_data_connection_counter(&self) -> Arc { + self.h1.unexpected_data_connection_counter() + } } #[cfg(test)] diff --git a/pingora-core/src/connectors/http/v1.rs b/pingora-core/src/connectors/http/v1.rs index 62ecfcb6..ab04b2f6 100644 --- a/pingora-core/src/connectors/http/v1.rs +++ b/pingora-core/src/connectors/http/v1.rs @@ -17,6 +17,8 @@ use crate::protocols::http::v1::client::HttpSession; use crate::upstreams::peer::Peer; use pingora_error::Result; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; use std::time::Duration; pub struct Connector { @@ -60,6 +62,17 @@ impl Connector { .release_stream(stream, peer.reuse_hash(), idle_timeout); } } + + /// Return the number of times a pooled connection was found to contain + /// unexpected data from the server. + pub fn unexpected_data_connection_count(&self) -> u64 { + self.transport.unexpected_data_connection_count() + } + + /// Return a shared reference to the unexpected data connection counter for periodic metric reporting. + pub fn unexpected_data_connection_counter(&self) -> Arc { + self.transport.unexpected_data_connection_counter() + } } #[cfg(test)] diff --git a/pingora-core/src/connectors/mod.rs b/pingora-core/src/connectors/mod.rs index e5e987cb..0e3c727c 100644 --- a/pingora-core/src/connectors/mod.rs +++ b/pingora-core/src/connectors/mod.rs @@ -37,6 +37,7 @@ use pingora_error::{Error, ErrorType::*, OrErr, Result}; use pingora_pool::{ConnectionMeta, ConnectionPool}; use std::collections::HashMap; use std::net::SocketAddr; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use tls::TlsConnector; use tokio::sync::Mutex; @@ -146,6 +147,9 @@ pub struct TransportConnector { bind_to_v4: Vec, bind_to_v6: Vec, preferred_http_version: PreferredHttpVersion, + /// Wrapped in `Arc` so external consumers (e.g. proxy services) can clone a reference + /// for periodic metric reporting without needing access to the connector itself. + unexpected_data_conn_count: Arc, } const DEFAULT_POOL_SIZE: usize = 128; @@ -172,6 +176,7 @@ impl TransportConnector { bind_to_v4, bind_to_v6, preferred_http_version: PreferredHttpVersion::new(), + unexpected_data_conn_count: Arc::new(AtomicU64::new(0)), } } @@ -212,7 +217,9 @@ impl TransportConnector { // test_reusable_stream: we assume server would never actively send data // first on an idle stream. #[cfg(unix)] - if peer.matches_fd(stream.id()) && test_reusable_stream(&mut stream) { + if peer.matches_fd(stream.id()) + && test_reusable_stream(&mut stream, &self.unexpected_data_conn_count) + { Some(stream) } else { None @@ -227,7 +234,10 @@ impl TransportConnector { } } if peer.matches_sock(WrappedRawSocket(stream.id() as RawSocket)) - && test_reusable_stream(&mut stream) + && test_reusable_stream( + &mut stream, + &self.unexpected_data_conn_count, + ) { Some(stream) } else { @@ -261,7 +271,7 @@ impl TransportConnector { key: u64, // usually peer.reuse_hash() idle_timeout: Option, ) { - if !test_reusable_stream(&mut stream) { + if !test_reusable_stream(&mut stream, &self.unexpected_data_conn_count) { return; } let id = stream.id(); @@ -301,6 +311,21 @@ impl TransportConnector { pub fn prefer_h1(&self, peer: &impl Peer) { self.preferred_http_version.add(peer, 1); } + + /// Return the number of times a pooled connection was found to contain unexpected data + /// from the server. + pub fn unexpected_data_connection_count(&self) -> u64 { + self.unexpected_data_conn_count.load(Ordering::Relaxed) + } + + /// Return a shared reference to the unexpected data connection counter. + /// + /// This allows external consumers (e.g. proxy services) to clone the `Arc` and + /// periodically read the counter for metric reporting without needing ongoing + /// access to the connector. + pub fn unexpected_data_connection_counter(&self) -> Arc { + self.unexpected_data_conn_count.clone() + } } // Perform the actual L4 and tls connection steps while respecting the peer's @@ -376,7 +401,7 @@ use futures::future::FutureExt; use tokio::io::AsyncReadExt; /// Test whether a stream is already closed or not reusable (server sent unexpected data) -fn test_reusable_stream(stream: &mut Stream) -> bool { +fn test_reusable_stream(stream: &mut Stream, unexpected_data_conn_count: &AtomicU64) -> bool { let mut buf = [0; 1]; // tokio::task::unconstrained because now_or_never may yield None when the future is ready let result = tokio::task::unconstrained(stream.read(&mut buf[..])).now_or_never(); @@ -387,6 +412,7 @@ fn test_reusable_stream(stream: &mut Stream) -> bool { debug!("Idle connection is closed"); } else { warn!("Unexpected data read in idle connection"); + unexpected_data_conn_count.fetch_add(1, Ordering::Relaxed); } } Err(e) => { @@ -644,4 +670,32 @@ mod tests { let (etype, context) = get_do_connect_failure_with_peer(&peer).await; assert!(etype != ConnectTimedout || !context.contains("total-connection timeout")); } + + #[tokio::test] + async fn test_unexpected_data_connection_count_increments() { + // Create a duplex stream where we control both ends + let (mut server, client) = tokio::io::duplex(64); + + let counter = AtomicU64::new(0); + let mut stream: Stream = Box::new(client); + + // With no data available, the stream should be considered reusable + assert!(test_reusable_stream(&mut stream, &counter)); + assert_eq!(counter.load(Ordering::Relaxed), 0); + + // Write unexpected data from the server side + use tokio::io::AsyncWriteExt; + server.write_all(b"unexpected").await.unwrap(); + + // Give the data a moment to be buffered + tokio::task::yield_now().await; + + // Now test_reusable_stream should detect the unexpected data + assert!(!test_reusable_stream(&mut stream, &counter)); + assert_eq!( + counter.load(Ordering::Relaxed), + 1, + "unexpected_data_connection_count should have incremented" + ); + } } diff --git a/pingora-proxy/src/lib.rs b/pingora-proxy/src/lib.rs index 3a50f704..e5433efa 100644 --- a/pingora-proxy/src/lib.rs +++ b/pingora-proxy/src/lib.rs @@ -46,7 +46,7 @@ use pingora_http::{RequestHeader, ResponseHeader}; use std::fmt::Debug; use std::str; use std::sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicU64, Ordering}, Arc, }; use std::time::Duration; @@ -190,6 +190,17 @@ where } } + /// Return the number of times a pooled upstream connection was found to contain + /// unexpected data from the server. + pub fn unexpected_data_connection_count(&self) -> u64 { + self.client_upstream.unexpected_data_connection_count() + } + + /// Return a shared reference to the unexpected data connection counter for periodic metric reporting. + pub fn unexpected_data_connection_counter(&self) -> Arc { + self.client_upstream.unexpected_data_connection_counter() + } + /// Initialize the downstream modules for this proxy. /// /// This method must be called after creating an [`HttpProxy`] with [`HttpProxy::new()`] From d41a66b4f6268a99540b014c7e833e4721c32743 Mon Sep 17 00:00:00 2001 From: Fei Deng Date: Thu, 2 Apr 2026 13:14:24 -0400 Subject: [PATCH 22/52] update bench_lru with production-scale data, warn about promote_top_n Update bench_lru to test at production-level data sizes (~100K and ~500K items/shard). The original benchmark only tested 100 items across 10 shards (10 per shard), which made promote_top_n appear 42% faster. At larger scales, promote() is actually 20-25% faster because the read-lock scan rarely finds hot items near the head. Add heavy-hitter benchmarks (10 and 100 items at 10,000x weight) to test whether extremely concentrated access patterns benefit from promote_top_n. Result: promote() still ties or wins even with heavy hitters, because with few hot items spread across 32 shards, most shards have 0-1 hot items and the scan is wasted on cold accesses. Each benchmark variant uses a fresh LRU and a thread barrier to avoid state contamination and staggered starts. The 16M-item config is gated behind BENCH_LARGE=1 to avoid OOM on CI. Add a performance warning to promote_top_n() docs recommending promote() for large-scale workloads. --- .bleep | 2 +- pingora-lru/benches/bench_lru.rs | 316 ++++++++++++++++++++----------- pingora-lru/src/lib.rs | 15 +- 3 files changed, 221 insertions(+), 112 deletions(-) diff --git a/.bleep b/.bleep index 6436301c..5269af84 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -ddc5c39c76ea10f1bc83dbe58889e937031873e4 \ No newline at end of file +85c78ad06e98d7e93900693ce5135f54d2ee3341 \ No newline at end of file diff --git a/pingora-lru/benches/bench_lru.rs b/pingora-lru/benches/bench_lru.rs index c0bdc776..02dadec8 100644 --- a/pingora-lru/benches/bench_lru.rs +++ b/pingora-lru/benches/bench_lru.rs @@ -12,137 +12,237 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Benchmark for `Lru::promote()` vs `Lru::promote_top_n()`. +//! +//! Tests both small (original) and production-scale LRU sizes to show how +//! the `promote_top_n` optimization behaves at different scales. +//! +//! Run with: `cargo bench -p pingora-lru --bench bench_lru` +//! +//! ## Results (Apple M3 Max, 2026-04-03) +//! +//! Benchmark tiers simulate production-level data sizes (items/shard +//! ranging from ~100 to ~500K) with both uniform-hot and heavy-hitter +//! access patterns. +//! +//! ### 8-threaded — uniform hot set (10% of items are 100x hotter) +//! +//! | Items/shard | promote | top_n(0) | top_n(3) | top_n(10) | top_n(50) | top_n(100) | +//! |-------------|------------|----------|----------|-----------|-----------|------------| +//! | 10 (orig) | 366ns | 476ns | 271ns | **164ns** | 164ns | 164ns | +//! | 100K (typ) | **457ns** | 480ns | 437ns | 520ns | 1227ns | 2394ns | +//! +//! ### 8-threaded — heavy hitters (10 or 100 items are 10,000x hotter) +//! +//! | Items/shard | promote | top_n(0) | top_n(3) | top_n(10) | top_n(50) | top_n(100) | +//! |-----------------|------------|----------|----------|-----------|-----------|------------| +//! | 100K, 10 hot | **649ns** | 688ns | 652ns | 773ns | 1811ns | 3534ns | +//! | 100K, 100 hot | **607ns** | 632ns | 607ns | 716ns | 1493ns | 2759ns | +//! +//! ### Single-threaded — uniform hot set (10% of items are 100x hotter) +//! +//! | Items/shard | promote | top_n(0) | top_n(3) | top_n(10) | top_n(50) | top_n(100) | +//! |-------------|------------|----------|----------|-----------|-----------|------------| +//! | 10 (orig) | 22ns | 20ns | 29ns | **30ns** | 30ns | 30ns | +//! | 100K (typ) | **297ns** | 306ns | 314ns | 332ns | 663ns | 1092ns | +//! +//! **Conclusions**: +//! +//! - `promote_top_n(0)` is strictly worse than `promote()` — it takes a +//! wasted read lock before falling through to the write lock every time. +//! +//! - `promote_top_n(n)` for n > 0 only wins at the original small scale +//! (10 items/shard) where the threshold covers the entire shard. +//! +//! - Even with heavy-hitter patterns (10 items at 10,000x weight), +//! `promote()` ties or wins at production scale. With 10 hot items +//! across 32 shards, most shards have 0-1 hot items, so the read-lock +//! scan is wasted on the majority of cold-item accesses. +//! +//! - At production scale (~100K+ items/shard), plain `promote()` is fastest +//! regardless of access pattern. + use rand::distributions::WeightedIndex; use rand::prelude::*; -use std::sync::Arc; +use std::sync::{Arc, Barrier}; use std::thread; use std::time::Instant; -// Non-uniform distributions, 100 items, 10 of them are 100x more likely to appear -const WEIGHTS: &[usize] = &[ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, - 100, 100, 100, 100, 100, 100, 100, -]; - const ITERATIONS: usize = 5_000_000; const THREADS: usize = 8; -fn main() { - let lru = parking_lot::Mutex::new(lru::LruCache::::unbounded()); - - let plru = pingora_lru::Lru::<(), 10>::with_capacity(1000, 100); - // populate first, then we bench access/promotion - for i in 0..WEIGHTS.len() { - lru.lock().put(i as u64, ()); - } - for i in 0..WEIGHTS.len() { - plru.admit(i as u64, (), 1); +/// Build a weight distribution where the first `hot_count` items have +/// `hot_weight`x the access probability. +fn make_weights(n: usize, hot_count: usize, hot_weight: usize) -> Vec { + let mut weights = vec![1usize; n]; + for w in weights.iter_mut().take(hot_count) { + *w = hot_weight; } + weights +} - // single thread - let mut rng = thread_rng(); - let dist = WeightedIndex::new(WEIGHTS).unwrap(); +fn bench_config(label: &str, items: usize, shards: usize, hot_pct: usize, hot_weight: usize) { + let hot_count = items * hot_pct / 100; + bench_config_abs(label, items, shards, hot_count, hot_weight); +} - let before = Instant::now(); - for _ in 0..ITERATIONS { - lru.lock().get(&(dist.sample(&mut rng) as u64)); - } - let elapsed = before.elapsed(); +fn bench_config_abs(label: &str, items: usize, shards: usize, hot_count: usize, hot_weight: usize) { println!( - "lru promote total {elapsed:?}, {:?} avg per operation", - elapsed / ITERATIONS as u32 + "\n=== {label}: {items} items, {shards} shards ({} per shard), \ + {hot_count} items are {hot_weight}x hotter ===", + items / shards ); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote(dist.sample(&mut rng) as u64); - } - let elapsed = before.elapsed(); - println!( - "pingora lru promote total {elapsed:?}, {:?} avg per operation", - elapsed / ITERATIONS as u32 - ); + let weights = make_weights(items, hot_count, hot_weight); + let dist = Arc::new(WeightedIndex::new(&weights).unwrap()); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote_top_n(dist.sample(&mut rng) as u64, 10); + match shards { + 10 => bench_shards::<10>(items, &dist), + 32 => bench_shards::<32>(items, &dist), + _ => panic!("unsupported shard count: {shards}"), } - let elapsed = before.elapsed(); - println!( - "pingora lru promote_top_10 total {elapsed:?}, {:?} avg per operation", - elapsed / ITERATIONS as u32 - ); +} - // concurrent - - let lru = Arc::new(lru); - let mut handlers = vec![]; - for i in 0..THREADS { - let lru = lru.clone(); - let handler = thread::spawn(move || { - let mut rng = thread_rng(); - let dist = WeightedIndex::new(WEIGHTS).unwrap(); - let before = Instant::now(); - for _ in 0..ITERATIONS { - lru.lock().get(&(dist.sample(&mut rng) as u64)); - } - let elapsed = before.elapsed(); - println!( - "lru promote total {elapsed:?}, {:?} avg per operation thread {i}", - elapsed / ITERATIONS as u32 - ); - }); - handlers.push(handler); +/// Populate a fresh LRU with `items` entries. +fn make_lru(items: usize) -> pingora_lru::Lru<(), N> { + let lru = pingora_lru::Lru::<(), N>::with_capacity(items, items / N); + for i in 0..items { + lru.admit(i as u64, (), 1); } - for thread in handlers { - thread.join().unwrap(); + lru +} + +fn bench_shards(items: usize, dist: &Arc>) { + // Each variant gets a fresh LRU to avoid state contamination from + // prior runs warming hot items to the head. + + // --- Single-threaded --- + println!(" Single-threaded:"); + { + let lru = make_lru::(items); + let mut rng = thread_rng(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote(dist.sample(&mut rng) as u64); + } + let elapsed = before.elapsed(); + println!( + " promote: {elapsed:?} total, {:?} avg", + elapsed / ITERATIONS as u32, + ); } - let plru = Arc::new(plru); - - let mut handlers = vec![]; - for i in 0..THREADS { - let plru = plru.clone(); - let handler = thread::spawn(move || { - let mut rng = thread_rng(); - let dist = WeightedIndex::new(WEIGHTS).unwrap(); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote(dist.sample(&mut rng) as u64); - } - let elapsed = before.elapsed(); - println!( - "pingora lru promote total {elapsed:?}, {:?} avg per operation thread {i}", - elapsed / ITERATIONS as u32 - ); - }); - handlers.push(handler); + for top_n in [0, 3, 10, 50, 100] { + let lru = make_lru::(items); + let mut rng = thread_rng(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote_top_n(dist.sample(&mut rng) as u64, top_n); + } + let elapsed = before.elapsed(); + println!( + " promote_top_{top_n:<3} {elapsed:?} total, {:?} avg", + elapsed / ITERATIONS as u32, + ); } - for thread in handlers { - thread.join().unwrap(); + + // --- Multi-threaded --- + println!(" {THREADS}-threaded:"); + + { + let lru = Arc::new(make_lru::(items)); + let barrier = Arc::new(Barrier::new(THREADS)); + let mut handlers = vec![]; + for _ in 0..THREADS { + let lru = lru.clone(); + let dist = Arc::clone(dist); + let barrier = barrier.clone(); + handlers.push(thread::spawn(move || { + let mut rng = thread_rng(); + barrier.wait(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote(dist.sample(&mut rng) as u64); + } + before.elapsed() + })); + } + let elapsed: Vec<_> = handlers.into_iter().map(|h| h.join().unwrap()).collect(); + let avg = elapsed.iter().sum::() / THREADS as u32; + println!( + " promote: avg {avg:?}, {:?} avg per op", + avg / ITERATIONS as u32, + ); } - let mut handlers = vec![]; - for i in 0..THREADS { - let plru = plru.clone(); - let handler = thread::spawn(move || { - let mut rng = thread_rng(); - let dist = WeightedIndex::new(WEIGHTS).unwrap(); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote_top_n(dist.sample(&mut rng) as u64, 10); - } - let elapsed = before.elapsed(); - println!( - "pingora lru promote_top_10 total {elapsed:?}, {:?} avg per operation thread {i}", - elapsed / ITERATIONS as u32 - ); - }); - handlers.push(handler); + for top_n in [0, 3, 10, 50, 100] { + let lru = Arc::new(make_lru::(items)); + let barrier = Arc::new(Barrier::new(THREADS)); + let mut handlers = vec![]; + for _ in 0..THREADS { + let lru = lru.clone(); + let dist = Arc::clone(dist); + let barrier = barrier.clone(); + handlers.push(thread::spawn(move || { + let mut rng = thread_rng(); + barrier.wait(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote_top_n(dist.sample(&mut rng) as u64, top_n); + } + before.elapsed() + })); + } + let elapsed: Vec<_> = handlers.into_iter().map(|h| h.join().unwrap()).collect(); + let avg = elapsed.iter().sum::() / THREADS as u32; + println!( + " promote_top_{top_n:<3} avg {avg:?}, {:?} avg per op", + avg / ITERATIONS as u32, + ); } - for thread in handlers { - thread.join().unwrap(); +} + +fn main() { + // Benchmark tiers to simulate production-level data sizes: + // Small = original bench scale (10 items/shard) + // Typical = ~100K items/shard (3.2M total across 32 shards) + // Large = ~500K items/shard (16M total) — gated behind + // BENCH_LARGE=1 to avoid OOM on CI runners (~1.5GB heap) + // + // Note: the Typical tier allocates ~150MB per make_lru() call. With + // multiple variants (promote + 5 top_n values) × configs, total peak + // memory is ~1GB. Well within CI limits but notable for constrained machines. + + // Original benchmark scale (100 items, 10 shards = 10 per shard) + bench_config("Small (original bench scale)", 100, 10, 10, 100); + + // Typical (~100K items/shard), 10% hot + bench_config("Typical (100K/shard, 10% hot)", 3_200_000, 32, 10, 100); + + // Typical (~100K items/shard), heavy-hitter: only 10 items dominate + // Simulates viral content / popular API endpoints where a handful of + // assets receive the vast majority of traffic. + bench_config_abs( + "Typical (100K/shard, 10 heavy hitters)", + 3_200_000, + 32, + 10, + 10_000, + ); + + // Typical (~100K items/shard), moderate hot set: 100 items dominate + bench_config_abs( + "Typical (100K/shard, 100 heavy hitters)", + 3_200_000, + 32, + 100, + 10_000, + ); + + // Large (~500K items/shard, ~1.5GB heap) + if std::env::var("BENCH_LARGE").is_ok() { + bench_config("Large (500K/shard, 10% hot)", 16_000_000, 32, 10, 100); + } else { + println!("\n=== Skipping large bench (set BENCH_LARGE=1 to enable) ==="); } } diff --git a/pingora-lru/src/lib.rs b/pingora-lru/src/lib.rs index af0b2d91..67f59230 100644 --- a/pingora-lru/src/lib.rs +++ b/pingora-lru/src/lib.rs @@ -128,9 +128,18 @@ impl Lru { /// Promote to the top n of the LRU /// - /// This function is a bit more efficient in terms of reducing lock contention because it - /// will acquire a write lock only if the key is outside top n but only acquires a read lock - /// when the key is already in the top n. + /// This function acquires a read lock first to check if the key is already + /// in the top `n` positions. If so, it returns early without a write lock. + /// Otherwise it falls through to a write lock for the actual promotion. + /// + /// **Performance note**: this optimization only helps when `n` covers a + /// significant fraction of the shard. At production scale (~100K+ items + /// per shard), hot items are rarely in the top N positions, so the + /// read-lock scan is usually wasted work that adds latency without + /// reducing contention. Benchmarks (`cargo bench --bench bench_lru`) + /// show that plain [`promote()`](Self::promote) is faster at scale. + /// Consider using `promote()` directly unless profiling shows a clear + /// benefit for your workload. /// /// Return false if the item doesn't exist pub fn promote_top_n(&self, key: u64, top: usize) -> bool { From 842ddd9fac9ee8570eb1e5b8ea208fbc88e7671c Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Wed, 1 Apr 2026 15:56:13 -0700 Subject: [PATCH 23/52] Split out pingora-prometheus into a separate crate --- .bleep | 2 +- Cargo.toml | 1 + docs/user_guide/modify_filter.md | 3 +- docs/user_guide/prom.md | 20 +-- pingora-core/Cargo.toml | 2 - pingora-core/src/apps/mod.rs | 2 - pingora-core/src/apps/prometheus_http_app.rs | 66 ---------- pingora-core/src/services/listening.rs | 16 --- pingora-prometheus/Cargo.toml | 22 ++++ pingora-prometheus/src/lib.rs | 131 +++++++++++++++++++ pingora-proxy/Cargo.toml | 2 +- pingora-proxy/examples/gateway.rs | 6 +- pingora/Cargo.toml | 3 +- pingora/examples/server.rs | 7 +- 14 files changed, 167 insertions(+), 116 deletions(-) delete mode 100644 pingora-core/src/apps/prometheus_http_app.rs create mode 100644 pingora-prometheus/Cargo.toml create mode 100644 pingora-prometheus/src/lib.rs diff --git a/.bleep b/.bleep index 5269af84..7452d585 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -85c78ad06e98d7e93900693ce5135f54d2ee3341 \ No newline at end of file +860cd189e019331d6106c586765ccf8be7e5ebd2 \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index d3c8603b..c78de1f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ members = [ "pingora-ketama", "pingora-load-balancing", "pingora-memory-cache", + "pingora-prometheus", "tinyufo", ] diff --git a/docs/user_guide/modify_filter.md b/docs/user_guide/modify_filter.md index 3e5378fb..a833fc27 100644 --- a/docs/user_guide/modify_filter.md +++ b/docs/user_guide/modify_filter.md @@ -123,8 +123,7 @@ impl ProxyHttp for MyGateway { fn main() { ... - let mut prometheus_service_http = - pingora::services::listening::Service::prometheus_http_service(); + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("127.0.0.1:6192"); my_server.add_service(prometheus_service_http); diff --git a/docs/user_guide/prom.md b/docs/user_guide/prom.md index b1868f12..1e83c0f3 100644 --- a/docs/user_guide/prom.md +++ b/docs/user_guide/prom.md @@ -1,29 +1,21 @@ # Prometheus -Pingora has a built-in prometheus HTTP metric server for scraping. +The [`pingora-prometheus`](https://docs.rs/pingora-prometheus) crate provides a +Prometheus HTTP metrics server for scraping. -## Enabling Prometheus Support +## Adding the Dependency -Prometheus support is an optional feature in Pingora. To use it, you need to enable the `prometheus` feature in your `Cargo.toml`: +Add `pingora-prometheus` to your `Cargo.toml`: ```toml -# If using the main pingora crate -pingora = { version = "0.8.0", features = ["prometheus"] } - -# If using pingora-core directly -pingora-core = { version = "0.8.0", features = ["prometheus"] } - -# If using pingora-proxy crate -pingora-proxy = { version = "0.8.0", features = ["prometheus"] } +pingora-prometheus = "0.8.0" ``` ## Setting up a Prometheus Metrics Endpoint -Once the feature is enabled, you can set up a Prometheus metrics endpoint like this: - ```rust ... - let mut prometheus_service_http = Service::prometheus_http_service(); + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("0.0.0.0:1234"); my_server.add_service(prometheus_service_http); my_server.run_forever(); diff --git a/pingora-core/Cargo.toml b/pingora-core/Cargo.toml index e2854966..12ff7a23 100644 --- a/pingora-core/Cargo.toml +++ b/pingora-core/Cargo.toml @@ -47,7 +47,6 @@ strum = "0.26.2" strum_macros = "0.26.2" libc = "0.2.70" chrono = { version = "~0.4.31", features = ["alloc"], default-features = false } -prometheus = { version = "0.14", optional = true } sentry = { version = "0.36", features = [ "backtrace", "contexts", @@ -108,4 +107,3 @@ openssl_derived = ["any_tls"] any_tls = [] sentry = ["dep:sentry"] connection_filter = [] -prometheus = ["dep:prometheus"] diff --git a/pingora-core/src/apps/mod.rs b/pingora-core/src/apps/mod.rs index 8c087489..82989e5c 100644 --- a/pingora-core/src/apps/mod.rs +++ b/pingora-core/src/apps/mod.rs @@ -15,8 +15,6 @@ //! The abstraction and implementation interface for service application logic pub mod http_app; -#[cfg(feature = "prometheus")] -pub mod prometheus_http_app; use crate::server::ShutdownWatch; use async_trait::async_trait; diff --git a/pingora-core/src/apps/prometheus_http_app.rs b/pingora-core/src/apps/prometheus_http_app.rs deleted file mode 100644 index f06cce7d..00000000 --- a/pingora-core/src/apps/prometheus_http_app.rs +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2026 Cloudflare, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! An HTTP application that reports Prometheus metrics. - -#[cfg(feature = "prometheus")] -mod prometheus_impl { - use async_trait::async_trait; - use http::Response; - use prometheus::{Encoder, TextEncoder}; - - use super::super::http_app::HttpServer; - use crate::apps::http_app::ServeHttp; - use crate::modules::http::compression::ResponseCompressionBuilder; - use crate::protocols::http::ServerSession; - - /// An HTTP application that reports Prometheus metrics. - /// - /// This application will report all the [static metrics](https://docs.rs/prometheus/latest/prometheus/index.html#static-metrics) - /// collected via the [Prometheus](https://docs.rs/prometheus/) crate; - pub struct PrometheusHttpApp; - - #[async_trait] - impl ServeHttp for PrometheusHttpApp { - async fn response(&self, _http_session: &mut ServerSession) -> Response> { - let encoder = TextEncoder::new(); - let metric_families = prometheus::gather(); - let mut buffer = vec![]; - encoder.encode(&metric_families, &mut buffer).unwrap(); - Response::builder() - .status(200) - .header(http::header::CONTENT_TYPE, encoder.format_type()) - .header(http::header::CONTENT_LENGTH, buffer.len()) - .body(buffer) - .unwrap() - } - } - - /// The [HttpServer] for [PrometheusHttpApp] - /// - /// This type provides the functionality of [PrometheusHttpApp] with compression enabled - pub type PrometheusServer = HttpServer; - - impl PrometheusServer { - pub fn new() -> Self { - let mut server = Self::new_app(PrometheusHttpApp); - // enable gzip level 7 compression - server.add_module(ResponseCompressionBuilder::enable(7)); - server - } - } -} - -#[cfg(feature = "prometheus")] -pub use prometheus_impl::*; diff --git a/pingora-core/src/services/listening.rs b/pingora-core/src/services/listening.rs index b6886c21..7b718b9b 100644 --- a/pingora-core/src/services/listening.rs +++ b/pingora-core/src/services/listening.rs @@ -309,19 +309,3 @@ impl ServiceTrait for Service { self.threads } } - -#[cfg(feature = "prometheus")] -use crate::apps::prometheus_http_app::PrometheusServer; - -#[cfg(feature = "prometheus")] -impl Service { - /// The Prometheus HTTP server - /// - /// The HTTP server endpoint that reports Prometheus metrics collected in the entire service - pub fn prometheus_http_service() -> Self { - Service::new( - "Prometheus metric HTTP".to_string(), - PrometheusServer::new(), - ) - } -} diff --git a/pingora-prometheus/Cargo.toml b/pingora-prometheus/Cargo.toml new file mode 100644 index 00000000..d9701213 --- /dev/null +++ b/pingora-prometheus/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "pingora-prometheus" +version = "0.8.0" +authors = ["Pingora Team at Cloudflare "] +license = "Apache-2.0" +edition = "2021" +repository = "https://github.com/cloudflare/pingora" +categories = ["asynchronous", "network-programming"] +keywords = ["async", "http", "prometheus", "pingora"] +description = """ +A Prometheus metrics HTTP server for pingora services. +""" + +[lib] +name = "pingora_prometheus" +path = "src/lib.rs" + +[dependencies] +pingora-core = { version = "0.8.0", path = "../pingora-core", default-features = false } +prometheus = "0.14" +async-trait = { workspace = true } +http = { workspace = true } diff --git a/pingora-prometheus/src/lib.rs b/pingora-prometheus/src/lib.rs new file mode 100644 index 00000000..cfd90b89 --- /dev/null +++ b/pingora-prometheus/src/lib.rs @@ -0,0 +1,131 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![warn(clippy::all)] + +//! A Prometheus metrics HTTP server for [pingora](https://docs.rs/pingora) services. +//! +//! This crate provides [`PrometheusHttpApp`] and [`PrometheusServer`], which serve +//! all [static metrics](https://docs.rs/prometheus/latest/prometheus/index.html#static-metrics) +//! collected via the [`prometheus`] crate as an HTTP endpoint. +//! +//! # Example +//! +//! ```rust,ignore +//! use pingora_core::services::listening::Service; +//! use pingora_prometheus::new_prometheus_server; +//! +//! let mut prometheus_service = Service::new( +//! "Prometheus HTTP".to_string(), +//! new_prometheus_server(), +//! ); +//! prometheus_service.add_tcp("127.0.0.1:6150"); +//! server.add_service(prometheus_service); +//! ``` +//! +//! Or use the convenience function: +//! +//! ```rust,ignore +//! let mut prometheus_service = pingora_prometheus::prometheus_http_service(); +//! prometheus_service.add_tcp("127.0.0.1:6150"); +//! server.add_service(prometheus_service); +//! ``` + +use async_trait::async_trait; +use http::Response; +use prometheus::{Encoder, TextEncoder}; + +use pingora_core::apps::http_app::{HttpServer, ServeHttp}; +use pingora_core::modules::http::compression::ResponseCompressionBuilder; +use pingora_core::protocols::http::ServerSession; +use pingora_core::services::listening::Service; + +/// Re-export of the [`prometheus`] crate. +/// +/// Use this re-export to ensure your metrics are registered in the same +/// global registry that [`PrometheusHttpApp`] gathers from, avoiding +/// version mismatches that would cause metrics to silently not appear. +/// +/// # Example +/// +/// ```rust,ignore +/// use pingora_prometheus::prometheus::{self, register_int_counter, IntCounter}; +/// use once_cell::sync::Lazy; +/// +/// static REQUESTS: Lazy = Lazy::new(|| { +/// register_int_counter!("requests_total", "Total requests").unwrap() +/// }); +/// ``` +pub use prometheus; + +/// An HTTP application that reports Prometheus metrics. +/// +/// This application will report all the [static metrics](https://docs.rs/prometheus/latest/prometheus/index.html#static-metrics) +/// collected via the [Prometheus](https://docs.rs/prometheus/) crate. +/// +/// Currently serves metrics on all request paths. By convention, Prometheus +/// scrapers expect metrics at `/metrics`. Since this app is typically bound +/// to a dedicated listener address, this works in practice, but callers +/// should be aware of this if sharing the listener with other routes. +// TODO: consider restricting to `/metrics` and returning 404 for other paths +pub struct PrometheusHttpApp; + +#[async_trait] +impl ServeHttp for PrometheusHttpApp { + async fn response(&self, _http_session: &mut ServerSession) -> Response> { + let encoder = TextEncoder::new(); + let metric_families = prometheus::gather(); + let mut buffer = vec![]; + encoder.encode(&metric_families, &mut buffer).unwrap(); + Response::builder() + .status(200) + .header(http::header::CONTENT_TYPE, encoder.format_type()) + .header(http::header::CONTENT_LENGTH, buffer.len()) + .body(buffer) + .unwrap() + } +} + +/// The [`HttpServer`] for [`PrometheusHttpApp`]. +/// +/// This type provides the functionality of [`PrometheusHttpApp`] with gzip +/// compression enabled (level 7). +pub type PrometheusServer = HttpServer; + +/// Create a new [`PrometheusServer`] with compression enabled. +pub fn new_prometheus_server() -> PrometheusServer { + let mut server = PrometheusServer::new_app(PrometheusHttpApp); + // enable gzip level 7 compression + server.add_module(ResponseCompressionBuilder::enable(7)); + server +} + +/// Create a Prometheus HTTP [`Service`] ready to have endpoints added. +/// +/// This is a convenience function that creates a [`Service`] wrapping a +/// [`PrometheusServer`] with compression enabled. +/// +/// # Example +/// +/// ```rust,ignore +/// let mut prometheus_service = pingora_prometheus::prometheus_http_service(); +/// prometheus_service.add_tcp("127.0.0.1:6150"); +/// server.add_service(prometheus_service); +/// ``` +pub fn prometheus_http_service() -> Service { + Service::new( + "Prometheus metric HTTP".to_string(), + new_prometheus_server(), + ) +} diff --git a/pingora-proxy/Cargo.toml b/pingora-proxy/Cargo.toml index d4df1378..d82179cb 100644 --- a/pingora-proxy/Cargo.toml +++ b/pingora-proxy/Cargo.toml @@ -47,6 +47,7 @@ hyper = "0.14" tokio-tungstenite = "0.20.1" pingora-limits = { version = "0.8.0", path = "../pingora-limits" } pingora-load-balancing = { version = "0.8.0", path = "../pingora-load-balancing", default-features=false } +pingora-prometheus = { version = "0.8.0", path = "../pingora-prometheus" } prometheus = "0" futures-util = "0.3" serde = { version = "1.0", features = ["derive"] } @@ -71,7 +72,6 @@ any_tls = [] sentry = ["pingora-core/sentry"] adjust_upstream_modules = [] connection_filter = ["pingora-core/connection_filter"] -prometheus = ["pingora-core/prometheus"] trace = ["pingora-cache/trace"] [[example]] diff --git a/pingora-proxy/examples/gateway.rs b/pingora-proxy/examples/gateway.rs index e320688f..79c1646c 100644 --- a/pingora-proxy/examples/gateway.rs +++ b/pingora-proxy/examples/gateway.rs @@ -129,12 +129,8 @@ fn main() { my_proxy.add_tcp("0.0.0.0:6191"); my_server.add_service(my_proxy); - #[cfg(feature = "prometheus")] - let mut prometheus_service_http = - pingora_core::services::listening::Service::prometheus_http_service(); - #[cfg(feature = "prometheus")] + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("127.0.0.1:6192"); - #[cfg(feature = "prometheus")] my_server.add_service(prometheus_service_http); my_server.run_forever(); diff --git a/pingora/Cargo.toml b/pingora/Cargo.toml index d9fb57c2..7de8640f 100644 --- a/pingora/Cargo.toml +++ b/pingora/Cargo.toml @@ -29,6 +29,7 @@ pingora-load-balancing = { version = "0.8.0", path = "../pingora-load-balancing" pingora-proxy = { version = "0.8.0", path = "../pingora-proxy", optional = true, default-features = false } pingora-cache = { version = "0.8.0", path = "../pingora-cache", optional = true, default-features = false } + # Only used for documenting features, but doesn't work in any other dependency # group :( document-features = { version = "0.2.10", optional = true } @@ -42,6 +43,7 @@ hyper = "0.14" async-trait = { workspace = true } http = { workspace = true } log = { workspace = true } +pingora-prometheus = { version = "0.8.0", path = "../pingora-prometheus" } prometheus = "0.14" once_cell = { workspace = true } bytes = { workspace = true } @@ -152,5 +154,4 @@ document-features = [ "sentry", "connection_filter" ] -prometheus = ["pingora-core/prometheus"] trace = ["pingora-cache?/trace", "pingora-proxy?/trace"] diff --git a/pingora/examples/server.rs b/pingora/examples/server.rs index 1e299140..37a246cd 100644 --- a/pingora/examples/server.rs +++ b/pingora/examples/server.rs @@ -20,8 +20,6 @@ use pingora::protocols::TcpKeepalive; use pingora::server::configuration::Opt; use pingora::server::{Server, ShutdownWatch}; use pingora::services::background::{background_service, BackgroundService}; -#[cfg(feature = "prometheus")] -use pingora::services::listening::Service as ListeningService; use pingora::services::ServiceWithDependents; use async_trait::async_trait; @@ -187,9 +185,7 @@ pub fn main() { &key_path, ); - #[cfg(feature = "prometheus")] - let mut prometheus_service_http = ListeningService::prometheus_http_service(); - #[cfg(feature = "prometheus")] + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("127.0.0.1:6150"); let background_service = background_service("example", ExampleBackgroundService {}); @@ -199,7 +195,6 @@ pub fn main() { Box::new(echo_service_http), Box::new(proxy_service), Box::new(proxy_service_ssl), - #[cfg(feature = "prometheus")] Box::new(prometheus_service_http), Box::new(background_service), ]; From 211405690f6adab55e585e2edd9121d6891f0a6f Mon Sep 17 00:00:00 2001 From: ewang Date: Fri, 3 Apr 2026 18:56:20 -0700 Subject: [PATCH 24/52] Make h2 stream window and conn window size configurable --- .bleep | 2 +- pingora-core/src/connectors/http/v2.rs | 203 ++++++++++++++++++++-- pingora-core/src/protocols/http/v2/mod.rs | 5 +- pingora-core/src/upstreams/peer.rs | 14 +- 4 files changed, 209 insertions(+), 15 deletions(-) diff --git a/.bleep b/.bleep index 7452d585..b2b6ca2a 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -860cd189e019331d6106c586765ccf8be7e5ebd2 \ No newline at end of file +e2089546a5962c0f65c081211d604dadd9330195 \ No newline at end of file diff --git a/pingora-core/src/connectors/http/v2.rs b/pingora-core/src/connectors/http/v2.rs index dd1d2b27..0b70b66e 100644 --- a/pingora-core/src/connectors/http/v2.rs +++ b/pingora-core/src/connectors/http/v2.rs @@ -343,8 +343,13 @@ impl Connector { // the caller that the server speaks h2c } } - let max_h2_stream = peer.get_peer_options().map_or(1, |o| o.max_h2_streams); - let conn = handshake(stream, max_h2_stream, peer.h2_ping_interval()).await?; + let peer_options = peer.get_peer_options(); + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = peer_options.map_or(1, |o| o.max_h2_streams); + settings.ping_interval = peer.h2_ping_interval(); + settings.stream_window_size = peer_options.and_then(|o| o.h2_stream_window_size); + settings.connection_window_size = peer_options.and_then(|o| o.h2_connection_window_size); + let conn = handshake(stream, settings).await?; let h2_stream = conn .spawn_stream() .await? @@ -484,19 +489,84 @@ impl Connector { // 8 Mbytes = 80 Mbytes X 100ms, which should be enough for most links. const H2_WINDOW_SIZE: u32 = 1 << 23; -pub async fn handshake( - stream: Stream, - max_streams: usize, - h2_ping_interval: Option, -) -> Result { +/// Maximum allowed H2 window size per [RFC 9113 §6.9.1](https://datatracker.ietf.org/doc/html/rfc9113#section-6.9.1-7). +const H2_MAX_WINDOW_SIZE: u32 = (1u32 << 31) - 1; + +/// Settings for HTTP/2 handshake. +/// +/// # Example +/// +/// ```rust,ignore +/// use pingora_core::connectors::http::v2::{handshake, H2HandshakeSettings}; +/// +/// // With custom window sizes +/// let mut settings = H2HandshakeSettings::new(); +/// settings.max_streams = 100; +/// settings.stream_window_size = Some(1 << 20); // 1MiB +/// settings.connection_window_size = Some(1 << 24); // 16MiB +/// let conn = handshake(stream, settings).await?; +/// ``` +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct H2HandshakeSettings { + /// The maximum number of concurrent streams allowed on this connection. + pub max_streams: usize, + /// Optional interval for sending H2 ping frames to keep the connection alive. + pub ping_interval: Option, + /// Optional initial per-stream receive window size in bytes. + /// If `None`, the default of 8MB is used. + pub stream_window_size: Option, + /// Optional initial connection-level receive window size in bytes. + /// If `None`, the default of 8MB is used. + pub connection_window_size: Option, +} + +impl H2HandshakeSettings { + /// Create a new `H2HandshakeSettings` with all defaults. + pub fn new() -> Self { + Self::default() + } +} + +/// Perform an HTTP/2 handshake on the given stream with the given settings. +pub async fn handshake(stream: Stream, settings: H2HandshakeSettings) -> Result { use h2::client::Builder; use pingora_runtime::current_handle; + let max_streams = settings.max_streams; + // Safe guard: new_http_session() assumes there should be at least one free stream if max_streams == 0 { return Error::e_explain(H2Error, "zero max_stream configured"); } + // Validate window sizes against RFC 9113 §6.9.1 limit + // https://datatracker.ietf.org/doc/html/rfc9113#section-6.9.1-7 + if settings + .stream_window_size + .is_some_and(|w| w == 0 || w > H2_MAX_WINDOW_SIZE) + { + return Error::e_explain( + H2Error, + format!( + "stream_window_size must be between 1 and {} (2^31-1)", + H2_MAX_WINDOW_SIZE + ), + ); + } + if settings + .connection_window_size + .is_some_and(|w| w == 0 || w > H2_MAX_WINDOW_SIZE) + { + return Error::e_explain( + H2Error, + format!( + "connection_window_size must be between 1 and {} (2^31-1)", + H2_MAX_WINDOW_SIZE + ), + ); + } + let id = stream.id(); let digest = Digest { // NOTE: this field is always false because the digest is shared across all streams @@ -507,16 +577,16 @@ pub async fn handshake( proxy_digest: stream.get_proxy_digest(), socket_digest: stream.get_socket_digest(), }; - // TODO: make these configurable + let stream_window = settings.stream_window_size.unwrap_or(H2_WINDOW_SIZE); + let conn_window = settings.connection_window_size.unwrap_or(H2_WINDOW_SIZE); let (send_req, connection) = Builder::new() .enable_push(false) .initial_max_send_streams(max_streams) // The limit for the server. Server push is not allowed, so this value doesn't matter .max_concurrent_streams(1) .max_frame_size(64 * 1024) // advise server to send larger frames - .initial_window_size(H2_WINDOW_SIZE) - // should this be max_streams * H2_WINDOW_SIZE? - .initial_connection_window_size(H2_WINDOW_SIZE) + .initial_window_size(stream_window) + .initial_connection_window_size(conn_window) .handshake(stream) .await .or_err(HandshakeError, "during H2 handshake")?; @@ -538,7 +608,7 @@ pub async fn handshake( connection, id, closed_tx, - h2_ping_interval, + settings.ping_interval, ping_timeout_clone, ) .await; @@ -558,6 +628,9 @@ pub async fn handshake( mod tests { use super::*; use crate::upstreams::peer::HttpPeer; + use bytes::Bytes; + use http::{Response, StatusCode}; + use pingora_http::RequestHeader; #[tokio::test] #[cfg(feature = "any_tls")] @@ -818,4 +891,110 @@ mod tests { .unwrap() .is_none()); } + + #[tokio::test] + async fn test_h2_handshake_settings_validation() { + use super::H2HandshakeSettings; + + // Test zero stream window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.stream_window_size = Some(0); + let (client, _server) = tokio::io::duplex(65536); + match handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("stream_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for stream_window_size = 0"), + } + + // Test zero connection window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.connection_window_size = Some(0); + let (client, _server) = tokio::io::duplex(65536); + match handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("connection_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for connection_window_size = 0"), + } + + // Test exceeding max stream window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.stream_window_size = Some(super::H2_MAX_WINDOW_SIZE + 1); + let (client, _server) = tokio::io::duplex(65536); + match handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("stream_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for stream_window_size > max"), + } + + // Test exceeding max connection window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.connection_window_size = Some(super::H2_MAX_WINDOW_SIZE + 1); + let (client, _server) = tokio::io::duplex(65536); + match handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("connection_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for connection_window_size > max"), + } + } + + #[tokio::test] + async fn test_h2_handshake_custom_window_sizes() { + // Test that valid custom window sizes are accepted and handshake succeeds + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.stream_window_size = Some(1 << 20); // 1MiB + settings.connection_window_size = Some(1 << 24); // 16MiB + + let (client, server) = tokio::io::duplex(65536); + + // Spawn server side + tokio::spawn(async move { + let mut server_conn = h2::server::handshake(server).await.unwrap(); + if let Some(result) = server_conn.accept().await { + let (_request, mut respond) = result.unwrap(); + let resp = Response::builder().status(StatusCode::OK).body(()).unwrap(); + let mut stream = respond.send_response(resp, false).unwrap(); + stream.send_data(Bytes::from("ok"), true).unwrap(); + server_conn.graceful_shutdown(); + } + // Drive the server connection until the client closes + while let Some(_res) = server_conn.accept().await {} + }); + + // Client side - should succeed with custom window sizes + let conn = handshake(Box::new(client), settings).await.unwrap(); + + // Verify we can spawn a stream and complete a request/response cycle + let mut stream = conn.spawn_stream().await.unwrap().unwrap(); + let mut request = RequestHeader::build("GET", b"/", None).unwrap(); + request + .insert_header(http::header::HOST, "example.com") + .unwrap(); + stream + .write_request_header(Box::new(request), true) + .unwrap(); + + stream.read_response_header().await.unwrap(); + assert_eq!(stream.response_header().unwrap().status, 200); + } } diff --git a/pingora-core/src/protocols/http/v2/mod.rs b/pingora-core/src/protocols/http/v2/mod.rs index 01711807..615fcee5 100644 --- a/pingora-core/src/protocols/http/v2/mod.rs +++ b/pingora-core/src/protocols/http/v2/mod.rs @@ -111,7 +111,10 @@ mod test { // Client handles.push(tokio::spawn(async move { - let conn = crate::connectors::http::v2::handshake(Box::new(client), 500, None) + use crate::connectors::http::v2::H2HandshakeSettings; + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 500; + let conn = crate::connectors::http::v2::handshake(Box::new(client), settings) .await .unwrap(); diff --git a/pingora-core/src/upstreams/peer.rs b/pingora-core/src/upstreams/peer.rs index c9ae0a66..78c6dbcc 100644 --- a/pingora-core/src/upstreams/peer.rs +++ b/pingora-core/src/upstreams/peer.rs @@ -431,8 +431,14 @@ pub struct PeerOptions { pub s2n_security_policy: Option, #[cfg(feature = "s2n")] pub max_blinding_delay: Option, - // how many concurrent h2 stream are allowed in the same connection + /// How many concurrent h2 streams are allowed in the same connection. pub max_h2_streams: usize, + /// Initial per-stream H2 receive window size in bytes. + /// If `None`, the default of 8MB is used. + pub h2_stream_window_size: Option, + /// Initial connection-level H2 receive window size in bytes. + /// If `None`, the default of 8MB is used. + pub h2_connection_window_size: Option, /// Allow invalid Content-Length in HTTP/1 responses (non-RFC compliant). /// /// When enabled, invalid Content-Length responses are treated as close-delimited responses. @@ -494,6 +500,8 @@ impl PeerOptions { #[cfg(feature = "s2n")] max_blinding_delay: None, max_h2_streams: 1, + h2_stream_window_size: None, + h2_connection_window_size: None, allow_h1_response_invalid_content_length: false, extra_proxy_headers: BTreeMap::new(), curves: None, @@ -685,6 +693,10 @@ impl Hash for HttpPeer { self.group_key.hash(state); // max h2 stream settings self.options.max_h2_streams.hash(state); + // h2_stream_window_size and h2_connection_window_size are intentionally excluded + // from the reuse hash for now. These are per-connection settings applied at handshake + // time and may be revisited alongside other h2 settings that could be dynamically + // adjusted over the lifetime of a connection. } } From c0adfd32c216a3bec14371ec4467236f34a6f9db Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Fri, 17 Apr 2026 14:41:46 -0700 Subject: [PATCH 25/52] Ignore caching stall tests for CI flakiness Also temp ignore the active RUSTSECs until the internal dependency bumps are synced. --- .cargo/audit.toml | 3 +++ pingora-proxy/tests/test_upstream.rs | 2 ++ 2 files changed, 5 insertions(+) create mode 100644 .cargo/audit.toml diff --git a/.cargo/audit.toml b/.cargo/audit.toml new file mode 100644 index 00000000..7c6e098f --- /dev/null +++ b/.cargo/audit.toml @@ -0,0 +1,3 @@ +[advisories] +# Temp before internal sync applies dependency bumps +ignore = ["RUSTSEC-2026-0097", "RUSTSEC-2026-0098", "RUSTSEC-2026-0099"] diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs index cdba09b7..ff6453d4 100644 --- a/pingora-proxy/tests/test_upstream.rs +++ b/pingora-proxy/tests/test_upstream.rs @@ -2914,6 +2914,7 @@ mod test_cache { } #[tokio::test] + #[ignore = "flaky in CI due to timing/resource contention"] async fn test_caching_when_downstream_stalls() { use std::net::ToSocketAddrs; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -2997,6 +2998,7 @@ mod test_cache { // to the origin over H2 (via the x-h2 header). // #[tokio::test] + #[ignore = "flaky in CI due to timing/resource contention"] async fn test_caching_h2_upstream_when_downstream_stalls() { use std::net::ToSocketAddrs; use tokio::io::{AsyncReadExt, AsyncWriteExt}; From 452813e6b4e03d18779eb81ecd7eb1dc508ba7bf Mon Sep 17 00:00:00 2001 From: Hrushikesh Deshpande Date: Thu, 23 Apr 2026 17:52:06 -0400 Subject: [PATCH 26/52] ci: add Semgrep OSS scanning workflow --- .github/workflows/semgrep.yml | 40 ++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml index b40314b3..3ae3dd57 100644 --- a/.github/workflows/semgrep.yml +++ b/.github/workflows/semgrep.yml @@ -1,24 +1,30 @@ +name: Semgrep OSS scan on: pull_request: {} + push: + branches: [main, master] workflow_dispatch: {} - push: - branches: - - main - - master schedule: - - cron: '0 0 * * *' -name: Semgrep config + - cron: '0 0 15 * *' +concurrency: + group: semgrep-${{ github.event_name }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true +permissions: + contents: read jobs: semgrep: - name: semgrep/ci - runs-on: ubuntu-latest - env: - SEMGREP_APP_TOKEN: ${{ secrets.SEMGREP_APP_TOKEN }} - SEMGREP_URL: https://cloudflare.semgrep.dev - SEMGREP_APP_URL: https://cloudflare.semgrep.dev - SEMGREP_VERSION_CHECK_URL: https://cloudflare.semgrep.dev/api/check-version - container: - image: returntocorp/semgrep + name: semgrep-oss + runs-on: ubuntu-slim steps: - - uses: actions/checkout@v4 - - run: semgrep ci + - uses: actions/checkout@v5 + with: + fetch-depth: 1 + - id: cache-semgrep + uses: actions/cache@v5 + with: + path: ~/.local + key: semgrep-1.160.0-${{ runner.os }} + - if: steps.cache-semgrep.outputs.cache-hit != 'true' + run: pip install --user semgrep==1.160.0 + - run: echo "$HOME/.local/bin" >> "$GITHUB_PATH" + - run: semgrep scan --config=auto From d4e4ae156a484e0ceeee8875c4e337228fd90c84 Mon Sep 17 00:00:00 2001 From: Matthew Gumport Date: Thu, 9 Apr 2026 20:10:44 +0000 Subject: [PATCH 27/52] vary on available-dictionary Dictionary-compressed responses should vary on Available-Dictionary (RFC 9842) so caches don't serve them to mismatched clients. This adds the header in the compression module. --- .bleep | 2 +- .../src/protocols/http/compression/mod.rs | 71 ++++++++++++++++++- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/.bleep b/.bleep index b2b6ca2a..195ccfea 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -e2089546a5962c0f65c081211d604dadd9330195 \ No newline at end of file +e149ce1ed42428ee3956b6fd572bdaddd46d7c5c \ No newline at end of file diff --git a/pingora-core/src/protocols/http/compression/mod.rs b/pingora-core/src/protocols/http/compression/mod.rs index 9e8e96fb..301de0fb 100644 --- a/pingora-core/src/protocols/http/compression/mod.rs +++ b/pingora-core/src/protocols/http/compression/mod.rs @@ -303,9 +303,18 @@ impl ResponseCompressionCtx { Action::Compress(algorithm) => { let idx = algorithm.index(); let compressor = match algorithm { - Algorithm::Dcz => dictionary.as_ref().and_then(|d| { - algorithm.maybe_compressor_with_dictionary(levels[idx], d) - }), + Algorithm::Dcz => { + // RFC 9842: dictionary-compressed responses vary on + // Available-Dictionary so caches don't serve this variant + // to clients with a different or missing dictionary. + let enc = dictionary.as_ref().and_then(|d| { + algorithm.maybe_compressor_with_dictionary(levels[idx], d) + }); + if enc.is_some() { + add_vary_header(resp, &AVAILABLE_DICTIONARY); + } + enc + } _ => algorithm.compressor(levels[idx]), }; (compressor, preserve_etag[idx]) @@ -780,6 +789,13 @@ fn compressible(resp: &ResponseHeader) -> bool { } } +/// Header name for the Available-Dictionary request header ([RFC 9842]). +/// TODO: Replace with http::header when available. +/// +/// [RFC 9842]: https://datatracker.ietf.org/doc/html/rfc9842 +static AVAILABLE_DICTIONARY: http::HeaderName = + http::HeaderName::from_static("available-dictionary"); + // add Vary header with the specified value or extend an existing Vary header value fn add_vary_header(resp: &mut ResponseHeader, value: &http::header::HeaderName) { use http::header::{HeaderValue, VARY}; @@ -1055,6 +1071,11 @@ mod tests_dictionary_compression { resp.headers.get("content-encoding").unwrap().as_bytes(), b"dcz" ); + // RFC 9842: DCZ responses must vary on Available-Dictionary. + assert!(resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"available-dictionary")))); let input = Bytes::from_static(b"The quick brown fox jumps over the lazy dog again."); let compressed = ctx.response_body_filter(Some(&input), true).unwrap(); @@ -1080,6 +1101,11 @@ mod tests_dictionary_compression { // no dictionary set, no compression applied assert!(resp.headers.get("content-encoding").is_none()); + // No compression → no Vary: available-dictionary. + assert!(!resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"available-dictionary")))); } #[test] @@ -1099,6 +1125,11 @@ mod tests_dictionary_compression { // dcz first but no dictionary, no automatic fallback assert!(resp.headers.get("content-encoding").is_none()); + // No compression → no Vary: available-dictionary. + assert!(!resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"available-dictionary")))); } #[test] @@ -1152,6 +1183,11 @@ mod tests_dictionary_compression { resp.headers.get("transfer-encoding").unwrap().as_bytes(), b"chunked" ); + // RFC 9842: DCZ responses must vary on Available-Dictionary. + assert!(resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"available-dictionary")))); let chunk1 = Bytes::from_static(b"First chunk. "); let output1 = ctx.response_body_filter(Some(&chunk1), false); @@ -1166,4 +1202,33 @@ mod tests_dictionary_compression { assert_eq!(total_in, chunk1.len() + chunk2.len()); assert!(total_out > 0); } + + #[test] + fn regular_compression_no_available_dictionary_vary() { + // Gzip compression should produce Vary: Accept-Encoding but NOT + // Vary: available-dictionary. + let mut ctx = ResponseCompressionCtx::new(3, false, false); + + let mut req = RequestHeader::build("GET", b"/page.html", None).unwrap(); + req.insert_header("accept-encoding", "gzip").unwrap(); + ctx.request_filter(&req); + + let mut resp = ResponseHeader::build(200, None).unwrap(); + resp.insert_header("content-type", "text/html").unwrap(); + resp.insert_header("content-length", "1000").unwrap(); + ctx.response_header_filter(&mut resp, false); + + assert_eq!( + resp.headers.get("content-encoding").unwrap().as_bytes(), + b"gzip" + ); + assert!(resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"accept-encoding")))); + assert!(!resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"available-dictionary")))); + } } From 6ac51b38b9ffa762223983cf39027c2808c03551 Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Thu, 9 Apr 2026 22:56:29 -0700 Subject: [PATCH 28/52] Add upstream module system This is analogous to the downstream modules but can apply prior to upstream compression. --- .bleep | 2 +- pingora-proxy/Cargo.toml | 2 +- pingora-proxy/src/lib.rs | 58 +++++++++++++++++++++++++++++++ pingora-proxy/src/proxy_custom.rs | 4 ++- pingora-proxy/src/proxy_h1.rs | 4 ++- pingora-proxy/src/proxy_h2.rs | 4 ++- pingora-proxy/src/proxy_trait.rs | 21 +++++++++-- pingora/Cargo.toml | 9 ++--- 8 files changed, 93 insertions(+), 11 deletions(-) diff --git a/.bleep b/.bleep index 195ccfea..da475ccc 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -e149ce1ed42428ee3956b6fd572bdaddd46d7c5c \ No newline at end of file +47de36f8b278e2c5624d50aa9d00ab98d52b35b5 \ No newline at end of file diff --git a/pingora-proxy/Cargo.toml b/pingora-proxy/Cargo.toml index d82179cb..c2e579c9 100644 --- a/pingora-proxy/Cargo.toml +++ b/pingora-proxy/Cargo.toml @@ -70,7 +70,7 @@ s2n = ["pingora-core/s2n", "pingora-cache/s2n", "any_tls"] openssl_derived = ["any_tls"] any_tls = [] sentry = ["pingora-core/sentry"] -adjust_upstream_modules = [] +upstream_modules = [] connection_filter = ["pingora-core/connection_filter"] trace = ["pingora-cache/trace"] diff --git a/pingora-proxy/src/lib.rs b/pingora-proxy/src/lib.rs index e5433efa..4ce9e5e5 100644 --- a/pingora-proxy/src/lib.rs +++ b/pingora-proxy/src/lib.rs @@ -119,6 +119,8 @@ where pub server_options: Option, pub h2_options: Option, pub downstream_modules: HttpModules, + #[cfg(feature = "upstream_modules")] + pub upstream_modules: HttpModules, max_retries: usize, process_custom_session: Option>, } @@ -153,6 +155,8 @@ impl HttpProxy { server_options: None, h2_options: None, downstream_modules: HttpModules::new(), + #[cfg(feature = "upstream_modules")] + upstream_modules: HttpModules::new(), max_retries: conf.max_retries, process_custom_session: None, } @@ -184,6 +188,8 @@ where shutdown_flag: Arc::new(AtomicBool::new(false)), server_options, downstream_modules: HttpModules::new(), + #[cfg(feature = "upstream_modules")] + upstream_modules: HttpModules::new(), max_retries: conf.max_retries, process_custom_session: on_custom, h2_options: None, @@ -215,6 +221,8 @@ where { self.inner .init_downstream_modules(&mut self.downstream_modules); + #[cfg(feature = "upstream_modules")] + self.inner.init_upstream_modules(&mut self.upstream_modules); } async fn handle_new_request( @@ -475,6 +483,10 @@ pub struct Session { pub subrequest_spawner: Option, // Downstream filter modules pub downstream_modules_ctx: HttpModuleCtx, + /// Upstream filter modules. These run before `upstream_compression` and see the raw + /// (pre-compression) upstream response body. + #[cfg(feature = "upstream_modules")] + pub upstream_modules_ctx: HttpModuleCtx, /// Upstream response body bytes received (payload only). Set by proxy layer. /// TODO: move this into an upstream session digest for future fields. upstream_body_bytes_received: usize, @@ -488,6 +500,7 @@ impl Session { fn new( downstream_session: impl Into>, downstream_modules: &HttpModules, + #[cfg(feature = "upstream_modules")] upstream_modules: &HttpModules, shutdown_flag: Arc, ) -> Self { Session { @@ -500,6 +513,8 @@ impl Session { subrequest_ctx: None, subrequest_spawner: None, // optionally set later on downstream_modules_ctx: downstream_modules.build_ctx(), + #[cfg(feature = "upstream_modules")] + upstream_modules_ctx: upstream_modules.build_ctx(), upstream_body_bytes_received: 0, upstream_write_pending_time: Duration::ZERO, shutdown_flag, @@ -515,6 +530,8 @@ impl Session { Self::new( Box::new(HttpSession::new_http1(stream)), &modules, + #[cfg(feature = "upstream_modules")] + &HttpModules::new(), Arc::new(AtomicBool::new(false)), ) } @@ -527,10 +544,47 @@ impl Session { Self::new( Box::new(HttpSession::new_http1(stream)), downstream_modules, + #[cfg(feature = "upstream_modules")] + &HttpModules::new(), Arc::new(AtomicBool::new(false)), ) } + /// Run upstream module filters on the given [`HttpTask`]. + /// + /// Upstream modules process each task **before** `upstream_compression` and + /// see the raw (pre-compression) upstream response. Like the downstream + /// module path, `response_trailer_filter` and `response_done_filter` return + /// values are converted to body tasks when present. + #[cfg(feature = "upstream_modules")] + pub async fn upstream_modules_filter_task(&mut self, t: &mut HttpTask) -> Result<()> { + match t { + HttpTask::Header(header, eos) => { + self.upstream_modules_ctx + .response_header_filter(header, *eos) + .await?; + } + HttpTask::Body(body, eos) | HttpTask::UpgradedBody(body, eos) => { + self.upstream_modules_ctx.response_body_filter(body, *eos)?; + } + HttpTask::Trailer(trailers) => { + if let Some(buf) = self + .upstream_modules_ctx + .response_trailer_filter(trailers)? + { + *t = HttpTask::Body(Some(buf), true); + } + } + HttpTask::Done => { + if let Some(buf) = self.upstream_modules_ctx.response_done_filter()? { + *t = HttpTask::Body(Some(buf), true); + } + } + HttpTask::Failed(_) => {} + } + Ok(()) + } + pub fn as_downstream_mut(&mut self) -> &mut HttpSession { &mut self.downstream_session } @@ -1099,6 +1153,8 @@ where Some(downstream_session) => Session::new( downstream_session, &self.downstream_modules, + #[cfg(feature = "upstream_modules")] + &self.upstream_modules, self.shutdown_flag.clone(), ), None => return, // bad request @@ -1226,6 +1282,8 @@ where Some(downstream_session) => Session::new( downstream_session, &self.downstream_modules, + #[cfg(feature = "upstream_modules")] + &self.upstream_modules, self.shutdown_flag.clone(), ), None => return None, // bad request diff --git a/pingora-proxy/src/proxy_custom.rs b/pingora-proxy/src/proxy_custom.rs index b7ee1d50..49f430be 100644 --- a/pingora-proxy/src/proxy_custom.rs +++ b/pingora-proxy/src/proxy_custom.rs @@ -293,12 +293,14 @@ where // skip downstream filtering entirely as the 304 will not be sent break; } - #[cfg(feature = "adjust_upstream_modules")] + #[cfg(feature = "upstream_modules")] if let HttpTask::Header(header, end_of_stream) = &t { self.inner .adjust_upstream_modules(session, header, *end_of_stream, ctx) .await?; } + #[cfg(feature = "upstream_modules")] + session.upstream_modules_filter_task(&mut t).await?; session.upstream_compression.response_filter(&mut t); // check error and abort // otherwise the error is surfaced via write_response_tasks() diff --git a/pingora-proxy/src/proxy_h1.rs b/pingora-proxy/src/proxy_h1.rs index dbf6e5ca..8222ec67 100644 --- a/pingora-proxy/src/proxy_h1.rs +++ b/pingora-proxy/src/proxy_h1.rs @@ -309,12 +309,14 @@ where // skip downstream filtering entirely as the 304 will not be sent break; } - #[cfg(feature = "adjust_upstream_modules")] + #[cfg(feature = "upstream_modules")] if let HttpTask::Header(header, end_of_stream) = &t { self.inner .adjust_upstream_modules(session, header, *end_of_stream, ctx) .await?; } + #[cfg(feature = "upstream_modules")] + session.upstream_modules_filter_task(&mut t).await?; session.upstream_compression.response_filter(&mut t); let task = self .h1_response_filter(session, t, ctx, serve_from_cache, range_body_filter, false) diff --git a/pingora-proxy/src/proxy_h2.rs b/pingora-proxy/src/proxy_h2.rs index afe58a0b..2fd74f60 100644 --- a/pingora-proxy/src/proxy_h2.rs +++ b/pingora-proxy/src/proxy_h2.rs @@ -306,12 +306,14 @@ where // skip downstream filtering entirely as the 304 will not be sent break; } - #[cfg(feature = "adjust_upstream_modules")] + #[cfg(feature = "upstream_modules")] if let HttpTask::Header(header, end_of_stream) = &t { self.inner .adjust_upstream_modules(session, header, *end_of_stream, ctx) .await?; } + #[cfg(feature = "upstream_modules")] + session.upstream_modules_filter_task(&mut t).await?; session.upstream_compression.response_filter(&mut t); // check error and abort // otherwise the error is surfaced via write_response_tasks() diff --git a/pingora-proxy/src/proxy_trait.rs b/pingora-proxy/src/proxy_trait.rs index b81fbb9b..2411092d 100644 --- a/pingora-proxy/src/proxy_trait.rs +++ b/pingora-proxy/src/proxy_trait.rs @@ -57,6 +57,23 @@ pub trait ProxyHttp { modules.add_module(ResponseCompressionBuilder::enable(0)); } + /// Set up upstream modules. + /// + /// In this phase, users can add [HttpModules] that will process upstream responses + /// **before** `upstream_compression`. This is the correct place to register modules + /// that need to observe the raw (pre-compression) upstream response body, such as + /// a dictionary store for shared dictionary compression. + /// + /// Upstream modules are ordered by [`HttpModuleBuilder::order()`]: higher values run + /// first. They are invoked on each upstream response task (header, body, trailers) + /// before `upstream_compression` processes the task. + /// + /// By default this method does nothing. + /// + /// This method requires the `upstream_modules` feature to be enabled. + #[cfg(feature = "upstream_modules")] + fn init_upstream_modules(&self, _modules: &mut HttpModules) {} + /// Handle the incoming request. /// /// In this phase, users can parse, validate, rate limit, perform access control and/or @@ -311,8 +328,8 @@ pub trait ProxyHttp { /// The response header is provided as an immutable reference. To modify the response header /// itself, use [`Self::upstream_response_filter()`] instead. /// - /// This filter requires the `adjust_upstream_modules` feature to be enabled. - #[cfg(feature = "adjust_upstream_modules")] + /// This filter requires the `upstream_modules` feature to be enabled. + #[cfg(feature = "upstream_modules")] async fn adjust_upstream_modules( &self, _session: &mut Session, diff --git a/pingora/Cargo.toml b/pingora/Cargo.toml index 7de8640f..4b828e90 100644 --- a/pingora/Cargo.toml +++ b/pingora/Cargo.toml @@ -128,11 +128,12 @@ time = [] ## Enable sentry for error notifications sentry = ["pingora-core/sentry"] -## Enable the `adjust_upstream_modules` filter phase on [ProxyHttp](crate::proxy::ProxyHttp) +## Enable upstream modules: the `adjust_upstream_modules` callback, the +## `upstream_modules_ctx` on Session, and `init_upstream_modules` on ProxyHttp. ## -## Allows configuring upstream modules (e.g. upstream compression) based on the -## response header before they process it. -adjust_upstream_modules = ["pingora-proxy?/adjust_upstream_modules"] +## Allows registering custom upstream modules that process response tasks +## before `upstream_compression`, and configuring them. +upstream_modules = ["pingora-proxy?/upstream_modules"] ## Enable pre-TLS connection filtering connection_filter = [ From 5e0f216a319a63d0f24c82d46afd57f2a8b41d26 Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Sat, 11 Apr 2026 16:39:48 -0700 Subject: [PATCH 29/52] Return error on new conn h2 spawn stream As opposed to panicking on an error while spawning a new stream, which may happen in rare situations if a server returns GOAWAY immediately upon creating the connection. --- .bleep | 2 +- pingora-core/src/connectors/http/v2.rs | 42 +++++++++++++++++++++++--- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/.bleep b/.bleep index da475ccc..9e4c650b 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -47de36f8b278e2c5624d50aa9d00ab98d52b35b5 \ No newline at end of file +529d46e6358f1be0fb76cd3b8d521d392e4d739b \ No newline at end of file diff --git a/pingora-core/src/connectors/http/v2.rs b/pingora-core/src/connectors/http/v2.rs index 0b70b66e..3cde4b89 100644 --- a/pingora-core/src/connectors/http/v2.rs +++ b/pingora-core/src/connectors/http/v2.rs @@ -24,7 +24,7 @@ use bytes::Bytes; use h2::client::SendRequest; use log::debug; use parking_lot::{Mutex, RwLock}; -use pingora_error::{Error, ErrorType::*, OrErr, Result}; +use pingora_error::{Error, ErrorType::*, OkOrErr, OrErr, Result}; use pingora_pool::{ConnectionMeta, ConnectionPool, PoolNode}; use std::collections::HashMap; use std::io::ErrorKind; @@ -350,10 +350,10 @@ impl Connector { settings.stream_window_size = peer_options.and_then(|o| o.h2_stream_window_size); settings.connection_window_size = peer_options.and_then(|o| o.h2_connection_window_size); let conn = handshake(stream, settings).await?; - let h2_stream = conn - .spawn_stream() - .await? - .expect("newly created connections should have at least one free stream"); + let h2_stream = conn.spawn_stream().await?.or_err( + H2Error, + "newly created connection has no free streams (server may have sent GOAWAY)", + )?; if conn.more_streams_allowed() { self.in_use_pool.insert(peer.reuse_hash(), conn); } @@ -997,4 +997,36 @@ mod tests { stream.read_response_header().await.unwrap(); assert_eq!(stream.response_header().unwrap().status, 200); } + + /// `spawn_stream()` must return `Ok(None)` when the server sends + /// GOAWAY(NO_ERROR) before any streams are opened. + #[tokio::test] + async fn test_spawn_stream_goaway_no_error_returns_none() { + let (client_io, server_io) = tokio::io::duplex(65536); + let (send_req, connection) = h2::client::handshake(client_io).await.unwrap(); + let (closed_tx, closed_rx) = watch::channel(false); + let ping_timeout = Arc::new(AtomicBool::new(false)); + let conn = ConnectionRef::new(send_req, closed_rx, ping_timeout, 0, 10, Digest::default()); + + let conn_handle = tokio::spawn(async move { + let _ = connection.await; + let _ = closed_tx.send(true); + }); + + let mut server_conn = h2::server::handshake(server_io).await.unwrap(); + server_conn.graceful_shutdown(); + let _ = server_conn.accept().await; + drop(server_conn); + + conn_handle.await.unwrap(); + + let result = conn.spawn_stream().await; + assert!( + result.is_ok(), + "expected Ok(None), got Err: {:?}", + result.as_ref().err() + ); + assert!(result.unwrap().is_none()); + assert!(conn.is_shutting_down()); + } } From 8b2fa503f9549a0ed30c860f0af35d76ef72b9ab Mon Sep 17 00:00:00 2001 From: Abhishek Aiyer Date: Tue, 14 Apr 2026 17:27:09 +0100 Subject: [PATCH 30/52] Strip H1-specific headers when downstream is a custom protocol and upstream is H2 When such a request reaches an H2 upstream, the existing version check (req.version != HTTP_2) may not fire if a malformed client sent hop-by-hop headers over H2. Add an is_custom() check so H1-specific headers are always stripped before forwarding to H2 when the downstream is a custom session. --- .bleep | 2 +- pingora-proxy/src/proxy_h2.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bleep b/.bleep index 9e4c650b..1f263354 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -529d46e6358f1be0fb76cd3b8d521d392e4d739b \ No newline at end of file +2060d9a18432f648494798fc2a93a3785eb44e1d \ No newline at end of file diff --git a/pingora-proxy/src/proxy_h2.rs b/pingora-proxy/src/proxy_h2.rs index 2fd74f60..20c491d2 100644 --- a/pingora-proxy/src/proxy_h2.rs +++ b/pingora-proxy/src/proxy_h2.rs @@ -89,7 +89,7 @@ where { let mut req = session.req_header().clone(); - if req.version != Version::HTTP_2 { + if req.version != Version::HTTP_2 || session.downstream_session.is_custom() { /* remove H1 specific headers */ // https://github.com/hyperium/h2/blob/d3b9f1e36aadc1a7a6804e2f8e86d3fe4a244b4f/src/proto/streams/send.rs#L72 req.remove_header(&http::header::TRANSFER_ENCODING); From 927a00c9e495a07a70a8e2e96a2d3b6c67083e89 Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Tue, 14 Apr 2026 17:28:01 -0700 Subject: [PATCH 31/52] Avoid hit handler finish on disabled cache This can happen when proxy tasks are enabled for downstream writes; an upstream miss handler error may end up disabling cache just as the downstream write finishes. In this and the non-proxy task case, the hit handler is dropped and no finish call should be made to begin with. --- .bleep | 2 +- pingora-proxy/src/proxy_cache.rs | 1 + pingora-proxy/src/proxy_custom.rs | 7 +++++-- pingora-proxy/src/proxy_h1.rs | 8 ++++++-- pingora-proxy/src/proxy_h2.rs | 8 ++++++-- 5 files changed, 19 insertions(+), 7 deletions(-) diff --git a/.bleep b/.bleep index 1f263354..f7667ea6 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -2060d9a18432f648494798fc2a93a3785eb44e1d \ No newline at end of file +69a651495f6dd240f0b95035ce5ae26ffad83c81 \ No newline at end of file diff --git a/pingora-proxy/src/proxy_cache.rs b/pingora-proxy/src/proxy_cache.rs index 748de963..fecf6fbd 100644 --- a/pingora-proxy/src/proxy_cache.rs +++ b/pingora-proxy/src/proxy_cache.rs @@ -487,6 +487,7 @@ where } } + // No enabled() guard: no concurrent upstream can disable cache here. if let Err(e) = session.cache.finish_hit_handler().await { warn!("Error during finish_hit_handler: {}", e); } diff --git a/pingora-proxy/src/proxy_custom.rs b/pingora-proxy/src/proxy_custom.rs index 49f430be..31cb3a52 100644 --- a/pingora-proxy/src/proxy_custom.rs +++ b/pingora-proxy/src/proxy_custom.rs @@ -528,7 +528,9 @@ where return Err(e); } } - if response_state.cached_done() { + // A storage error can disable cache between cached_done + // being set and here; see the same guard in proxy_h1.rs. + if response_state.cached_done() && session.cache.enabled() { if let Err(e) = session.cache.finish_hit_handler().await { warn!("Error during finish_hit_handler: {}", e); } @@ -552,7 +554,8 @@ where match write_result { Ok(end) => { response_state.maybe_set_cache_done(end); - if response_state.cached_done() { + // See enabled() guard comment above. + if response_state.cached_done() && session.cache.enabled() { if let Err(e) = session.cache.finish_hit_handler().await { warn!("Error during finish_hit_handler: {}", e); } diff --git a/pingora-proxy/src/proxy_h1.rs b/pingora-proxy/src/proxy_h1.rs index 8222ec67..e74309eb 100644 --- a/pingora-proxy/src/proxy_h1.rs +++ b/pingora-proxy/src/proxy_h1.rs @@ -601,7 +601,10 @@ where return Err(e); } } - if response_state.cached_done() { + // A storage error can disable cache between cached_done + // being set and here; disable() drops the enabled_ctx so + // finish_hit_handler would panic without this guard. + if response_state.cached_done() && session.cache.enabled() { if let Err(e) = session.cache.finish_hit_handler().await { warn!("Error during finish_hit_handler: {}", e); } @@ -625,7 +628,8 @@ where match write_result { Ok(end) => { response_state.maybe_set_cache_done(end); - if response_state.cached_done() { + // See enabled() guard comment above. + if response_state.cached_done() && session.cache.enabled() { if let Err(e) = session.cache.finish_hit_handler().await { warn!("Error during finish_hit_handler: {}", e); } diff --git a/pingora-proxy/src/proxy_h2.rs b/pingora-proxy/src/proxy_h2.rs index 20c491d2..e5030819 100644 --- a/pingora-proxy/src/proxy_h2.rs +++ b/pingora-proxy/src/proxy_h2.rs @@ -559,7 +559,9 @@ where return Err(e); } } - if response_state.cached_done() { + // A storage error can disable cache between cached_done + // being set and here; see the same guard in proxy_h1.rs. + if response_state.cached_done() && session.cache.enabled() { if let Err(e) = session.cache.finish_hit_handler().await { warn!("Error during finish_hit_handler: {}", e); } @@ -583,7 +585,9 @@ where match write_result { Ok(end) => { response_state.maybe_set_cache_done(end); - if response_state.cached_done() { + // See disabled() guard comment above. + // See enabled() guard comment above. + if response_state.cached_done() && session.cache.enabled() { if let Err(e) = session.cache.finish_hit_handler().await { warn!("Error during finish_hit_handler: {}", e); } From f6dadf844e7537a09695b1f2ea913eebb7be3fbf Mon Sep 17 00:00:00 2001 From: Kevin Guthrie Date: Fri, 24 Apr 2026 16:33:40 -0400 Subject: [PATCH 32/52] Syncing some mismatched internal/external changes --- pingora-core/src/connectors/l4.rs | 2 +- pingora-core/src/tls/mod.rs | 806 ------------------------------ pingora/tests/pingora_conf.yaml | 5 - 3 files changed, 1 insertion(+), 812 deletions(-) delete mode 100644 pingora-core/src/tls/mod.rs delete mode 100644 pingora/tests/pingora_conf.yaml diff --git a/pingora-core/src/connectors/l4.rs b/pingora-core/src/connectors/l4.rs index d3baaa63..d275030f 100644 --- a/pingora-core/src/connectors/l4.rs +++ b/pingora-core/src/connectors/l4.rs @@ -412,7 +412,7 @@ mod tests { let move_flag = Arc::clone(&flag); peer.options.upstream_tcp_sock_tweak_hook = Some(Arc::new(move |_| { - move_flag.fetch_xor(true, Ordering::SeqCst); + move_flag.fetch_not(Ordering::SeqCst); Ok(()) })); diff --git a/pingora-core/src/tls/mod.rs b/pingora-core/src/tls/mod.rs deleted file mode 100644 index 277b5b40..00000000 --- a/pingora-core/src/tls/mod.rs +++ /dev/null @@ -1,806 +0,0 @@ -// Copyright 2024 Cloudflare, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! This module contains a dummy TLS implementation for the scenarios where real TLS -//! implementations are unavailable. - -macro_rules! impl_display { - ($ty:ty) => { - impl std::fmt::Display for $ty { - fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - Ok(()) - } - } - }; -} - -macro_rules! impl_deref { - ($from:ty => $to:ty) => { - impl std::ops::Deref for $from { - type Target = $to; - fn deref(&self) -> &$to { - panic!("Not implemented"); - } - } - impl std::ops::DerefMut for $from { - fn deref_mut(&mut self) -> &mut $to { - panic!("Not implemented"); - } - } - }; -} - -pub mod ssl { - use super::error::ErrorStack; - use super::x509::verify::X509VerifyParamRef; - use super::x509::{X509VerifyResult, X509}; - - /// An error returned from an ALPN selection callback. - pub struct AlpnError; - impl AlpnError { - /// Terminate the handshake with a fatal alert. - pub const ALERT_FATAL: AlpnError = Self {}; - - /// Do not select a protocol, but continue the handshake. - pub const NOACK: AlpnError = Self {}; - } - - /// A type which allows for configuration of a client-side TLS session before connection. - pub struct ConnectConfiguration; - impl_deref! {ConnectConfiguration => SslRef} - impl ConnectConfiguration { - /// Configures the use of Server Name Indication (SNI) when connecting. - pub fn set_use_server_name_indication(&mut self, _use_sni: bool) { - panic!("Not implemented"); - } - - /// Configures the use of hostname verification when connecting. - pub fn set_verify_hostname(&mut self, _verify_hostname: bool) { - panic!("Not implemented"); - } - - /// Returns an `Ssl` configured to connect to the provided domain. - pub fn into_ssl(self, _domain: &str) -> Result { - panic!("Not implemented"); - } - - /// Like `SslContextBuilder::set_verify`. - pub fn set_verify(&mut self, _mode: SslVerifyMode) { - panic!("Not implemented"); - } - - /// Like `SslContextBuilder::set_alpn_protos`. - pub fn set_alpn_protos(&mut self, _protocols: &[u8]) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Returns a mutable reference to the X509 verification configuration. - pub fn param_mut(&mut self) -> &mut X509VerifyParamRef { - panic!("Not implemented"); - } - } - - /// An SSL error. - #[derive(Debug)] - pub struct Error; - impl_display!(Error); - impl Error { - pub fn code(&self) -> ErrorCode { - panic!("Not implemented"); - } - } - - /// An error code returned from SSL functions. - #[derive(PartialEq)] - pub struct ErrorCode(i32); - impl ErrorCode { - /// An error occurred in the SSL library. - pub const SSL: ErrorCode = Self(0); - } - - /// An identifier of a session name type. - pub struct NameType; - impl NameType { - pub const HOST_NAME: NameType = Self {}; - } - - /// The state of an SSL/TLS session. - pub struct Ssl; - impl Ssl { - /// Creates a new `Ssl`. - pub fn new(_ctx: &SslContextRef) -> Result { - panic!("Not implemented"); - } - } - impl_deref! {Ssl => SslRef} - - /// A type which wraps server-side streams in a TLS session. - pub struct SslAcceptor; - impl SslAcceptor { - /// Creates a new builder configured to connect to non-legacy clients. This should - /// generally be considered a reasonable default choice. - pub fn mozilla_intermediate_v5( - _method: SslMethod, - ) -> Result { - panic!("Not implemented"); - } - } - - /// A builder for `SslAcceptor`s. - pub struct SslAcceptorBuilder; - impl SslAcceptorBuilder { - /// Consumes the builder, returning a `SslAcceptor`. - pub fn build(self) -> SslAcceptor { - panic!("Not implemented"); - } - - /// Sets the callback used by a server to select a protocol for Application Layer Protocol - /// Negotiation (ALPN). - pub fn set_alpn_select_callback(&mut self, _callback: F) - where - F: for<'a> Fn(&mut SslRef, &'a [u8]) -> Result<&'a [u8], AlpnError> - + 'static - + Sync - + Send, - { - panic!("Not implemented"); - } - - /// Loads a certificate chain from a file. - pub fn set_certificate_chain_file>( - &mut self, - _file: P, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Loads the private key from a file. - pub fn set_private_key_file>( - &mut self, - _file: P, - _file_type: SslFiletype, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the maximum supported protocol version. - pub fn set_max_proto_version( - &mut self, - _version: Option, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - } - - /// Reference to an [`SslCipher`]. - pub struct SslCipherRef; - impl SslCipherRef { - /// Returns the name of the cipher. - pub fn name(&self) -> &'static str { - panic!("Not implemented"); - } - } - - /// A type which wraps client-side streams in a TLS session. - pub struct SslConnector; - impl SslConnector { - /// Creates a new builder for TLS connections. - pub fn builder(_method: SslMethod) -> Result { - panic!("Not implemented"); - } - - /// Returns a structure allowing for configuration of a single TLS session before connection. - pub fn configure(&self) -> Result { - panic!("Not implemented"); - } - - /// Returns a shared reference to the inner raw `SslContext`. - pub fn context(&self) -> &SslContextRef { - panic!("Not implemented"); - } - } - - /// A builder for `SslConnector`s. - pub struct SslConnectorBuilder; - impl SslConnectorBuilder { - /// Consumes the builder, returning an `SslConnector`. - pub fn build(self) -> SslConnector { - panic!("Not implemented"); - } - - /// Sets the list of supported ciphers for protocols before TLSv1.3. - pub fn set_cipher_list(&mut self, _cipher_list: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the context’s supported signature algorithms. - pub fn set_sigalgs_list(&mut self, _sigalgs: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the minimum supported protocol version. - pub fn set_min_proto_version( - &mut self, - _version: Option, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the maximum supported protocol version. - pub fn set_max_proto_version( - &mut self, - _version: Option, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Use the default locations of trusted certificates for verification. - pub fn set_default_verify_paths(&mut self) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Loads trusted root certificates from a file. - pub fn set_ca_file>( - &mut self, - _file: P, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Loads a leaf certificate from a file. - pub fn set_certificate_file>( - &mut self, - _file: P, - _file_type: SslFiletype, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Loads the private key from a file. - pub fn set_private_key_file>( - &mut self, - _file: P, - _file_type: SslFiletype, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the TLS key logging callback. - pub fn set_keylog_callback(&mut self, _callback: F) - where - F: Fn(&SslRef, &str) + 'static + Sync + Send, - { - panic!("Not implemented"); - } - } - - /// A context object for TLS streams. - pub struct SslContext; - impl SslContext { - /// Creates a new builder object for an `SslContext`. - pub fn builder(_method: SslMethod) -> Result { - panic!("Not implemented"); - } - } - impl_deref! {SslContext => SslContextRef} - - /// A builder for `SslContext`s. - pub struct SslContextBuilder; - impl SslContextBuilder { - /// Consumes the builder, returning a new `SslContext`. - pub fn build(self) -> SslContext { - panic!("Not implemented"); - } - } - - /// Reference to [`SslContext`] - pub struct SslContextRef; - - /// An identifier of the format of a certificate or key file. - pub struct SslFiletype; - impl SslFiletype { - /// The PEM format. - pub const PEM: SslFiletype = Self {}; - } - - /// A type specifying the kind of protocol an `SslContext`` will speak. - pub struct SslMethod; - impl SslMethod { - /// Support all versions of the TLS protocol. - pub fn tls() -> SslMethod { - panic!("Not implemented"); - } - } - - /// Reference to an [`Ssl`]. - pub struct SslRef; - impl SslRef { - /// Like [`SslContextBuilder::set_verify`]. - pub fn set_verify(&mut self, _mode: SslVerifyMode) { - panic!("Not implemented"); - } - - /// Returns the current cipher if the session is active. - pub fn current_cipher(&self) -> Option<&SslCipherRef> { - panic!("Not implemented"); - } - - /// Sets the host name to be sent to the server for Server Name Indication (SNI). - pub fn set_hostname(&mut self, _hostname: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Returns the peer’s certificate, if present. - pub fn peer_certificate(&self) -> Option { - panic!("Not implemented"); - } - - /// Returns the certificate verification result. - pub fn verify_result(&self) -> X509VerifyResult { - panic!("Not implemented"); - } - - /// Returns a string describing the protocol version of the session. - pub fn version_str(&self) -> &'static str { - panic!("Not implemented"); - } - - /// Returns the protocol selected via Application Layer Protocol Negotiation (ALPN). - pub fn selected_alpn_protocol(&self) -> Option<&[u8]> { - panic!("Not implemented"); - } - - /// Returns the servername sent by the client via Server Name Indication (SNI). - pub fn servername(&self, _type_: NameType) -> Option<&str> { - panic!("Not implemented"); - } - } - - /// Options controlling the behavior of certificate verification. - pub struct SslVerifyMode; - impl SslVerifyMode { - /// Verifies that the peer’s certificate is trusted. - pub const PEER: Self = Self {}; - - /// Disables verification of the peer’s certificate. - pub const NONE: Self = Self {}; - } - - /// An SSL/TLS protocol version. - pub struct SslVersion; - impl SslVersion { - /// TLSv1.0 - pub const TLS1: SslVersion = Self {}; - - /// TLSv1.2 - pub const TLS1_2: SslVersion = Self {}; - - /// TLSv1.3 - pub const TLS1_3: SslVersion = Self {}; - } - - /// A standard implementation of protocol selection for Application Layer Protocol Negotiation - /// (ALPN). - pub fn select_next_proto<'a>(_server: &[u8], _client: &'a [u8]) -> Option<&'a [u8]> { - panic!("Not implemented"); - } -} - -pub mod ssl_sys { - pub const X509_V_OK: i32 = 0; - pub const X509_V_ERR_INVALID_CALL: i32 = 69; -} - -pub mod error { - use super::ssl::Error; - - /// Collection of [`Errors`] from OpenSSL. - #[derive(Debug)] - pub struct ErrorStack; - impl_display!(ErrorStack); - impl std::error::Error for ErrorStack {} - impl ErrorStack { - /// Returns the contents of the OpenSSL error stack. - pub fn get() -> ErrorStack { - panic!("Not implemented"); - } - - /// Returns the errors in the stack. - pub fn errors(&self) -> &[Error] { - panic!("Not implemented"); - } - } -} - -pub mod x509 { - use super::asn1::{Asn1IntegerRef, Asn1StringRef, Asn1TimeRef}; - use super::error::ErrorStack; - use super::hash::{DigestBytes, MessageDigest}; - use super::nid::Nid; - - /// An `X509` public key certificate. - #[derive(Debug, Clone)] - pub struct X509; - impl_deref! {X509 => X509Ref} - impl X509 { - /// Deserializes a PEM-encoded X509 structure. - pub fn from_pem(_pem: &[u8]) -> Result { - panic!("Not implemented"); - } - } - - /// A type to destructure and examine an `X509Name`. - pub struct X509NameEntries<'a> { - marker: std::marker::PhantomData<&'a ()>, - } - impl<'a> Iterator for X509NameEntries<'a> { - type Item = &'a X509NameEntryRef; - fn next(&mut self) -> Option<&'a X509NameEntryRef> { - panic!("Not implemented"); - } - } - - /// Reference to `X509NameEntry`. - pub struct X509NameEntryRef; - impl X509NameEntryRef { - pub fn data(&self) -> &Asn1StringRef { - panic!("Not implemented"); - } - } - - /// Reference to `X509Name`. - pub struct X509NameRef; - impl X509NameRef { - /// Returns the name entries by the nid. - pub fn entries_by_nid(&self, _nid: Nid) -> X509NameEntries<'_> { - panic!("Not implemented"); - } - } - - /// Reference to `X509`. - pub struct X509Ref; - impl X509Ref { - /// Returns this certificate’s subject name. - pub fn subject_name(&self) -> &X509NameRef { - panic!("Not implemented"); - } - - /// Returns a digest of the DER representation of the certificate. - pub fn digest(&self, _hash_type: MessageDigest) -> Result { - panic!("Not implemented"); - } - - /// Returns the certificate’s Not After validity period. - pub fn not_after(&self) -> &Asn1TimeRef { - panic!("Not implemented"); - } - - /// Returns this certificate’s serial number. - pub fn serial_number(&self) -> &Asn1IntegerRef { - panic!("Not implemented"); - } - } - - /// The result of peer certificate verification. - pub struct X509VerifyResult; - impl X509VerifyResult { - /// Return the integer representation of an `X509VerifyResult`. - pub fn as_raw(&self) -> i32 { - panic!("Not implemented"); - } - } - - pub mod store { - use super::super::error::ErrorStack; - use super::X509; - - /// A builder type used to construct an `X509Store`. - pub struct X509StoreBuilder; - impl X509StoreBuilder { - /// Returns a builder for a certificate store.. - pub fn new() -> Result { - panic!("Not implemented"); - } - - /// Constructs the `X509Store`. - pub fn build(self) -> X509Store { - panic!("Not implemented"); - } - - /// Adds a certificate to the certificate store. - pub fn add_cert(&mut self, _cert: X509) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - } - - /// A certificate store to hold trusted X509 certificates. - pub struct X509Store; - impl_deref! {X509Store => X509StoreRef} - - /// Reference to an `X509Store`. - pub struct X509StoreRef; - } - - pub mod verify { - /// Reference to `X509VerifyParam`. - pub struct X509VerifyParamRef; - } -} - -pub mod nid { - /// A numerical identifier for an OpenSSL object. - pub struct Nid; - impl Nid { - pub const COMMONNAME: Nid = Self {}; - pub const ORGANIZATIONNAME: Nid = Self {}; - pub const ORGANIZATIONALUNITNAME: Nid = Self {}; - } -} - -pub mod pkey { - use super::error::ErrorStack; - - /// A public or private key. - #[derive(Clone)] - pub struct PKey { - marker: std::marker::PhantomData, - } - impl std::ops::Deref for PKey { - type Target = PKeyRef; - fn deref(&self) -> &PKeyRef { - panic!("Not implemented"); - } - } - impl std::ops::DerefMut for PKey { - fn deref_mut(&mut self) -> &mut PKeyRef { - panic!("Not implemented"); - } - } - impl PKey { - pub fn private_key_from_pem(_pem: &[u8]) -> Result, ErrorStack> { - panic!("Not implemented"); - } - } - - /// Reference to `PKey`. - pub struct PKeyRef { - marker: std::marker::PhantomData, - } - - /// A tag type indicating that a key has private components. - #[derive(Clone)] - pub enum Private {} - unsafe impl HasPrivate for Private {} - - /// A trait indicating that a key has private components. - pub unsafe trait HasPrivate {} -} - -pub mod hash { - /// A message digest algorithm. - pub struct MessageDigest; - impl MessageDigest { - pub fn sha256() -> MessageDigest { - panic!("Not implemented"); - } - } - - /// The resulting bytes of a digest. - pub struct DigestBytes; - impl AsRef<[u8]> for DigestBytes { - fn as_ref(&self) -> &[u8] { - panic!("Not implemented"); - } - } -} - -pub mod asn1 { - use super::bn::BigNum; - use super::error::ErrorStack; - - /// A reference to an `Asn1Integer`. - pub struct Asn1IntegerRef; - impl Asn1IntegerRef { - /// Converts the integer to a `BigNum`. - pub fn to_bn(&self) -> Result { - panic!("Not implemented"); - } - } - - /// A reference to an `Asn1String`. - pub struct Asn1StringRef; - impl Asn1StringRef { - pub fn as_utf8(&self) -> Result<&str, ErrorStack> { - panic!("Not implemented"); - } - } - - /// Reference to an `Asn1Time` - pub struct Asn1TimeRef; - impl_display! {Asn1TimeRef} -} - -pub mod bn { - use super::error::ErrorStack; - - /// Dynamically sized large number implementation - pub struct BigNum; - impl BigNum { - /// Returns a hexadecimal string representation of `self`. - pub fn to_hex_str(&self) -> Result<&str, ErrorStack> { - panic!("Not implemented"); - } - } -} - -pub mod ext { - use super::error::ErrorStack; - use super::pkey::{HasPrivate, PKeyRef}; - use super::ssl::{Ssl, SslAcceptor, SslRef}; - use super::x509::store::X509StoreRef; - use super::x509::verify::X509VerifyParamRef; - use super::x509::X509Ref; - - /// Add name as an additional reference identifier that can match the peer's certificate - pub fn add_host(_verify_param: &mut X509VerifyParamRef, _host: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Set the verify cert store of `_ssl` - pub fn ssl_set_verify_cert_store( - _ssl: &mut SslRef, - _cert_store: &X509StoreRef, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Load the certificate into `_ssl` - pub fn ssl_use_certificate(_ssl: &mut SslRef, _cert: &X509Ref) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Load the private key into `_ssl` - pub fn ssl_use_private_key(_ssl: &mut SslRef, _key: &PKeyRef) -> Result<(), ErrorStack> - where - T: HasPrivate, - { - panic!("Not implemented"); - } - - /// Clear the error stack - pub fn clear_error_stack() {} - - /// Create a new [Ssl] from &[SslAcceptor] - pub fn ssl_from_acceptor(_acceptor: &SslAcceptor) -> Result { - panic!("Not implemented"); - } - - /// Suspend the TLS handshake when a certificate is needed. - pub fn suspend_when_need_ssl_cert(_ssl: &mut SslRef) { - panic!("Not implemented"); - } - - /// Unblock a TLS handshake after the certificate is set. - pub fn unblock_ssl_cert(_ssl: &mut SslRef) { - panic!("Not implemented"); - } - - /// Whether the TLS error is SSL_ERROR_WANT_X509_LOOKUP - pub fn is_suspended_for_cert(_error: &super::ssl::Error) -> bool { - panic!("Not implemented"); - } - - /// Add the certificate into the cert chain of `_ssl` - pub fn ssl_add_chain_cert(_ssl: &mut SslRef, _cert: &X509Ref) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Set renegotiation - pub fn ssl_set_renegotiate_mode_freely(_ssl: &mut SslRef) {} - - /// Set the curves/groups of `_ssl` - pub fn ssl_set_groups_list(_ssl: &mut SslRef, _groups: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets whether a second keyshare to be sent in client hello when PQ is used. - pub fn ssl_use_second_key_share(_ssl: &mut SslRef, _enabled: bool) {} - - /// Get a mutable SslRef ouf of SslRef, which is a missing functionality even when holding &mut SslStream - /// # Safety - pub unsafe fn ssl_mut(_ssl: &SslRef) -> &mut SslRef { - panic!("Not implemented"); - } -} - -pub mod tokio_ssl { - use std::pin::Pin; - use std::task::{Context, Poll}; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - - use super::error::ErrorStack; - use super::ssl::{Error, Ssl, SslRef}; - - /// A TLS session over a stream. - #[derive(Debug)] - pub struct SslStream { - marker: std::marker::PhantomData, - } - impl SslStream { - /// Creates a new `SslStream`. - pub fn new(_ssl: Ssl, _stream: S) -> Result { - panic!("Not implemented"); - } - - /// Initiates a client-side TLS handshake. - pub async fn connect(self: Pin<&mut Self>) -> Result<(), Error> { - panic!("Not implemented"); - } - - /// Initiates a server-side TLS handshake. - pub async fn accept(self: Pin<&mut Self>) -> Result<(), Error> { - panic!("Not implemented"); - } - - /// Returns a shared reference to the `Ssl` object associated with this stream. - pub fn ssl(&self) -> &SslRef { - panic!("Not implemented"); - } - - /// Returns a shared reference to the underlying stream. - pub fn get_ref(&self) -> &S { - panic!("Not implemented"); - } - - /// Returns a mutable reference to the underlying stream. - pub fn get_mut(&mut self) -> &mut S { - panic!("Not implemented"); - } - } - impl AsyncRead for SslStream - where - S: AsyncRead + AsyncWrite, - { - fn poll_read( - self: Pin<&mut Self>, - _ctx: &mut Context<'_>, - _buf: &mut ReadBuf<'_>, - ) -> Poll> { - panic!("Not implemented"); - } - } - impl AsyncWrite for SslStream - where - S: AsyncRead + AsyncWrite, - { - fn poll_write( - self: Pin<&mut Self>, - _ctx: &mut Context<'_>, - _buf: &[u8], - ) -> Poll> { - panic!("Not implemented"); - } - - fn poll_flush(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { - panic!("Not implemented"); - } - - fn poll_shutdown( - self: Pin<&mut Self>, - _ctx: &mut Context<'_>, - ) -> Poll> { - panic!("Not implemented"); - } - } -} diff --git a/pingora/tests/pingora_conf.yaml b/pingora/tests/pingora_conf.yaml deleted file mode 100644 index c21ae15a..00000000 --- a/pingora/tests/pingora_conf.yaml +++ /dev/null @@ -1,5 +0,0 @@ ---- -version: 1 -client_bind_to_ipv4: - - 127.0.0.2 -ca_file: tests/keys/server.crt \ No newline at end of file From 3a95c50aa11239e1ca7fbb46f701568b9576c92c Mon Sep 17 00:00:00 2001 From: Kevin Guthrie Date: Fri, 24 Apr 2026 16:58:39 -0400 Subject: [PATCH 33/52] RUSTSEC-2026-0098 and RUSTSEC-2026-0099 fixes Bump dev-deps to pull in rustls-webpki 0.103.12. --- .bleep | 2 +- pingora-core/Cargo.toml | 9 +++-- pingora-core/src/connectors/tls/rustls/mod.rs | 3 ++ pingora-core/src/listeners/tls/rustls/mod.rs | 3 ++ pingora-core/tests/test_basic.rs | 5 ++- pingora-proxy/Cargo.toml | 11 +++--- pingora-proxy/tests/test_basic.rs | 34 +++++++++++-------- pingora-proxy/tests/test_upstream.rs | 31 +++++++++-------- pingora-rustls/Cargo.toml | 2 +- pingora-rustls/src/lib.rs | 14 ++++++-- pingora/Cargo.toml | 8 +++-- 11 files changed, 77 insertions(+), 45 deletions(-) diff --git a/.bleep b/.bleep index f7667ea6..03cba5ad 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -69a651495f6dd240f0b95035ce5ae26ffad83c81 \ No newline at end of file +77544580ab2cf44cc649ce2d77b180ef6c0aaa40 \ No newline at end of file diff --git a/pingora-core/Cargo.toml b/pingora-core/Cargo.toml index 12ff7a23..82dcb2f9 100644 --- a/pingora-core/Cargo.toml +++ b/pingora-core/Cargo.toml @@ -85,15 +85,18 @@ windows-sys = { version = "0.59.0", features = ["Win32_Networking_WinSock"] } h2 = { workspace = true, features = ["unstable"] } tokio-stream = { version = "0.1", features = ["full"] } env_logger = "0.11" -reqwest = { version = "0.11", features = [ +reqwest = { version = "0.12", features = [ "rustls-tls", + "http2", ], default-features = false } -hyper = "0.14" +hyper = { version = "1", features = ["client", "http1", "http2"] } +hyper-util = { version = "0.1", features = ["client-legacy", "http1", "http2"] } +http-body-util = "0.1" rstest = "0.23.0" rustls = "0.23" [target.'cfg(unix)'.dev-dependencies] -hyperlocal = "0.8" +hyperlocal = "0.9" jemallocator = "0.5" [features] diff --git a/pingora-core/src/connectors/tls/rustls/mod.rs b/pingora-core/src/connectors/tls/rustls/mod.rs index 58ea4085..23e3a307 100644 --- a/pingora-core/src/connectors/tls/rustls/mod.rs +++ b/pingora-core/src/connectors/tls/rustls/mod.rs @@ -61,6 +61,9 @@ impl TlsConnector { where Self: Sized, { + // rustls 0.23+ requires an explicit CryptoProvider. + pingora_rustls::install_default_crypto_provider(); + // NOTE: Rustls only supports TLS 1.2 & 1.3 // TODO: currently using Rustls defaults diff --git a/pingora-core/src/listeners/tls/rustls/mod.rs b/pingora-core/src/listeners/tls/rustls/mod.rs index 0ca94d51..e7376fc0 100644 --- a/pingora-core/src/listeners/tls/rustls/mod.rs +++ b/pingora-core/src/listeners/tls/rustls/mod.rs @@ -48,6 +48,9 @@ impl TlsSettings { /// /// Todo: Return a result instead of panicking XD pub fn build(self) -> Acceptor { + // rustls 0.23+ requires an explicit CryptoProvider. + pingora_rustls::install_default_crypto_provider(); + let Ok(Some((certs, key))) = load_certs_and_key_files(&self.cert_path, &self.key_path) else { panic!( diff --git a/pingora-core/tests/test_basic.rs b/pingora-core/tests/test_basic.rs index 445d75b9..ae6ee810 100644 --- a/pingora-core/tests/test_basic.rs +++ b/pingora-core/tests/test_basic.rs @@ -14,6 +14,8 @@ mod utils; +#[cfg(all(unix, feature = "any_tls"))] +use hyper_util::client::legacy::Client; #[cfg(all(unix, feature = "any_tls"))] use hyperlocal::{UnixClientExt, Uri}; @@ -55,7 +57,8 @@ async fn test_https_http2() { async fn test_uds() { utils::init(); let url = Uri::new("/tmp/echo.sock", "/").into(); - let client = hyper::Client::unix(); + let client: Client> = + Client::unix(); let res = client.get(url).await.unwrap(); assert_eq!(res.status(), reqwest::StatusCode::OK); diff --git a/pingora-proxy/Cargo.toml b/pingora-proxy/Cargo.toml index c2e579c9..e1cc1cbb 100644 --- a/pingora-proxy/Cargo.toml +++ b/pingora-proxy/Cargo.toml @@ -36,15 +36,18 @@ regex = "1" rand = "0.8" [dev-dependencies] -reqwest = { version = "0.11", features = [ +reqwest = { version = "0.12", features = [ "gzip", "rustls-tls", + "http2", ], default-features = false } httparse = { workspace = true } tokio-test = "0.4" env_logger = "0.11" -hyper = "0.14" -tokio-tungstenite = "0.20.1" +hyper = { version = "1", features = ["client", "http1", "http2"] } +hyper-util = { version = "0.1", features = ["client-legacy", "http1", "http2"] } +http-body-util = "0.1" +tokio-tungstenite = "0.26" pingora-limits = { version = "0.8.0", path = "../pingora-limits" } pingora-load-balancing = { version = "0.8.0", path = "../pingora-load-balancing", default-features=false } pingora-prometheus = { version = "0.8.0", path = "../pingora-prometheus" } @@ -55,7 +58,7 @@ serde_json = "1.0" serde_yaml = "0.9" [target.'cfg(unix)'.dev-dependencies] -hyperlocal = "0.8" +hyperlocal = "0.9" [features] default = [] diff --git a/pingora-proxy/tests/test_basic.rs b/pingora-proxy/tests/test_basic.rs index 77303fc3..cc48cb42 100644 --- a/pingora-proxy/tests/test_basic.rs +++ b/pingora-proxy/tests/test_basic.rs @@ -17,7 +17,8 @@ mod utils; use bytes::Bytes; use h2::client; use http::Request; -use hyper::{body::HttpBody, header::HeaderValue, Body, Client}; +use http_body_util::BodyExt; +use hyper_util::client::legacy::Client; #[cfg(unix)] use hyperlocal::{UnixClientExt, Uri}; use reqwest::{header, StatusCode}; @@ -161,21 +162,21 @@ async fn test_h2_to_h2() { async fn test_h2c_to_h2c() { init(); - let client = hyper::client::Client::builder() + let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) .http2_only(true) - .build_http(); + .build_http::>(); - let mut req = hyper::Request::builder() + let mut req = http::Request::builder() .uri("http://127.0.0.1:6146") - .body(Body::empty()) + .body(http_body_util::Empty::::new()) .unwrap(); req.headers_mut() - .insert("x-h2", HeaderValue::from_bytes(b"true").unwrap()); + .insert("x-h2", http::HeaderValue::from_bytes(b"true").unwrap()); let res = client.request(req).await.unwrap(); assert_eq!(res.status(), reqwest::StatusCode::OK); assert_eq!(res.version(), reqwest::Version::HTTP_2); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body.as_ref(), b"Hello World!\n"); } @@ -183,21 +184,21 @@ async fn test_h2c_to_h2c() { async fn test_h1_on_h2c_port() { init(); - let client = hyper::client::Client::builder() + let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) .http2_only(false) - .build_http(); + .build_http::>(); - let mut req = hyper::Request::builder() + let mut req = http::Request::builder() .uri("http://127.0.0.1:6146") - .body(Body::empty()) + .body(http_body_util::Empty::::new()) .unwrap(); req.headers_mut() - .insert("x-h2", HeaderValue::from_bytes(b"true").unwrap()); + .insert("x-h2", http::HeaderValue::from_bytes(b"true").unwrap()); let res = client.request(req).await.unwrap(); assert_eq!(res.status(), reqwest::StatusCode::OK); assert_eq!(res.version(), reqwest::Version::HTTP_11); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body.as_ref(), b"Hello World!\n"); } @@ -303,7 +304,7 @@ async fn test_h2_head() { async fn test_simple_proxy_uds() { init(); let url = Uri::new("/tmp/pingora_proxy.sock", "/").into(); - let client = Client::unix(); + let client: Client> = Client::unix(); let res = client.get(url).await.unwrap(); @@ -324,7 +325,10 @@ async fn test_simple_proxy_uds() { assert_eq!(sockaddr.ip().to_string(), "127.0.0.2"); assert!(is_specified_port(sockaddr.port())); - let body = hyper::body::to_bytes(body).await.unwrap(); + let body = http_body_util::BodyExt::collect(body) + .await + .unwrap() + .to_bytes(); assert_eq!(body.as_ref(), b"Hello World!\n"); } diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs index ff6453d4..7e85c2f8 100644 --- a/pingora-proxy/tests/test_upstream.rs +++ b/pingora-proxy/tests/test_upstream.rs @@ -17,9 +17,11 @@ mod utils; use utils::server_utils::init; use utils::websocket::{WS_ECHO, WS_ECHO_RAW}; +use bytes::Bytes; use futures::{SinkExt, StreamExt}; +use http::header::{HeaderName, HeaderValue}; +use http_body_util::BodyExt; use pingora_http::ResponseHeader; -use reqwest::header::{HeaderName, HeaderValue}; use reqwest::{StatusCode, Version}; use std::time::{Duration, Instant}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -181,7 +183,7 @@ async fn test_ws_server_ends_conn() { ws_stream.close(None).await.unwrap(); let msg = ws_stream.next().await.unwrap().unwrap(); // assert echo - assert_eq!("test", msg.into_text().unwrap()); + assert_eq!(msg.into_text().unwrap(), "test"); let msg = ws_stream.next().await.unwrap().unwrap(); // assert graceful close assert!(matches!(msg, Message::Close(None))); @@ -363,22 +365,22 @@ async fn test_upgrade_body_after_101() { #[tokio::test] async fn test_download_timeout() { init(); - use hyper::body::HttpBody; use tokio::time::sleep; - let client = hyper::Client::new(); - let uri: hyper::Uri = "http://127.0.0.1:6147/download_large/".parse().unwrap(); - let req = hyper::Request::builder() + let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) + .build_http::>(); + let uri: http::Uri = "http://127.0.0.1:6147/download_large/".parse().unwrap(); + let req = http::Request::builder() .uri(uri) .header("x-write-timeout", "1") - .body(hyper::Body::empty()) + .body(http_body_util::Empty::::new()) .unwrap(); let mut res = client.request(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let mut err = false; sleep(Duration::from_secs(2)).await; - while let Some(chunk) = res.body_mut().data().await { + while let Some(chunk) = res.body_mut().frame().await { if chunk.is_err() { err = true; } @@ -389,28 +391,27 @@ async fn test_download_timeout() { #[tokio::test] async fn test_download_timeout_min_rate() { init(); - use hyper::body::HttpBody; use tokio::time::sleep; - let client = hyper::Client::new(); - let uri: hyper::Uri = "http://127.0.0.1:6147/download/".parse().unwrap(); - let req = hyper::Request::builder() + let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) + .build_http::>(); + let uri: http::Uri = "http://127.0.0.1:6147/download/".parse().unwrap(); + let req = http::Request::builder() .uri(uri) .header("x-write-timeout", "1") .header("x-min-rate", "10000") - .body(hyper::Body::empty()) + .body(http_body_util::Empty::::new()) .unwrap(); let mut res = client.request(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let mut err = false; sleep(Duration::from_secs(2)).await; - while let Some(chunk) = res.body_mut().data().await { + while let Some(chunk) = res.body_mut().frame().await { if chunk.is_err() { err = true; } } - // no error as write timeout is overridden by min rate assert!(!err); } diff --git a/pingora-rustls/Cargo.toml b/pingora-rustls/Cargo.toml index efa377bf..51cd00ff 100644 --- a/pingora-rustls/Cargo.toml +++ b/pingora-rustls/Cargo.toml @@ -18,7 +18,7 @@ path = "src/lib.rs" log = "0.4.21" pingora-error = { version = "0.8.0", path = "../pingora-error"} ring = "0.17.12" -rustls = "0.23.12" +rustls = { version = "0.23.12", features = ["ring"] } rustls-native-certs = "0.7.1" rustls-pemfile = "2.1.2" rustls-pki-types = "1.7.0" diff --git a/pingora-rustls/src/lib.rs b/pingora-rustls/src/lib.rs index 097a8da5..deb0c88b 100644 --- a/pingora-rustls/src/lib.rs +++ b/pingora-rustls/src/lib.rs @@ -28,9 +28,19 @@ use pingora_error::{Error, ErrorType, OrErr, Result}; pub use rustls::server::danger::{ClientCertVerified, ClientCertVerifier}; pub use rustls::server::{ClientCertVerifierBuilder, WebPkiClientVerifier}; pub use rustls::{ - client::WebPkiServerVerifier, version, CertificateError, ClientConfig, DigitallySignedStruct, - Error as RusTlsError, KeyLogFile, RootCertStore, ServerConfig, SignatureScheme, Stream, + client::WebPkiServerVerifier, crypto::CryptoProvider, version, CertificateError, ClientConfig, + DigitallySignedStruct, Error as RusTlsError, KeyLogFile, RootCertStore, ServerConfig, + SignatureScheme, Stream, }; + +/// Install the default `ring` CryptoProvider for rustls. +/// +/// rustls 0.23+ requires an explicit provider. This function installs `ring` +/// as the process-level default. Safe to call multiple times — subsequent +/// calls are no-ops. +pub fn install_default_crypto_provider() { + let _ = CryptoProvider::install_default(rustls::crypto::ring::default_provider()); +} pub use rustls_native_certs::load_native_certs; use rustls_pemfile::Item; pub use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; diff --git a/pingora/Cargo.toml b/pingora/Cargo.toml index 4b828e90..8e30130a 100644 --- a/pingora/Cargo.toml +++ b/pingora/Cargo.toml @@ -38,8 +38,10 @@ document-features = { version = "0.2.10", optional = true } clap = { version = "4.5", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread", "signal"] } env_logger = "0.11" -reqwest = { version = "0.11", features = ["rustls"], default-features = false } -hyper = "0.14" +reqwest = { version = "0.12", features = ["rustls-tls", "http2"], default-features = false } +hyper = { version = "1", features = ["client", "http1", "http2"] } +hyper-util = { version = "0.1", features = ["client-legacy", "http1", "http2"] } +http-body-util = "0.1" async-trait = { workspace = true } http = { workspace = true } log = { workspace = true } @@ -50,7 +52,7 @@ bytes = { workspace = true } regex = "1" [target.'cfg(unix)'.dev-dependencies] -hyperlocal = "0.8" +hyperlocal = "0.9" jemallocator = "0.5" [features] From 1476e7a5eb6c2cfca2fffd5a82682c1a5262ac17 Mon Sep 17 00:00:00 2001 From: Matthew Gumport Date: Tue, 14 Apr 2026 13:16:26 -0700 Subject: [PATCH 34/52] expose pipe receiver in subrequest state The receiver drops when the coordinator exits the pipe loop, breaking the channel before the writer finishes its cache-write lifecycle. Return it in the state for callers to drain alongside the task handle. --- .bleep | 2 +- pingora-proxy/src/subrequest/pipe.rs | 34 +++++++++++++++++++--------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/.bleep b/.bleep index 03cba5ad..8639e006 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -77544580ab2cf44cc649ce2d77b180ef6c0aaa40 \ No newline at end of file +3a69c94066ae8bab489c56856c576bacfd341d6b \ No newline at end of file diff --git a/pingora-proxy/src/subrequest/pipe.rs b/pingora-proxy/src/subrequest/pipe.rs index 279a89de..7845d4dc 100644 --- a/pingora-proxy/src/subrequest/pipe.rs +++ b/pingora-proxy/src/subrequest/pipe.rs @@ -53,16 +53,25 @@ pub struct PipeSubrequestState { /// The spawned subrequest task handle. Always set after spawn. Caller is /// responsible for awaiting/inspecting state. pub join_handle: Option>, + /// The receiving half of the pipe channel. When the coordinator exits + /// `pipe_subrequest` before the subrequest task finishes writing, this + /// receiver must be kept alive and drained alongside the join handle; + /// otherwise dropping it breaks the pipe and prevents the writer from + /// completing its cache-write lifecycle. + pub pipe_rx: Option>, } impl PipeSubrequestState { /// Creates a snapshot for error reporting, excluding the join handle. + /// Moves `pipe_rx` into the snapshot so the receiver stays alive through + /// the error path and is not dropped when `self` is cleaned up. /// Used by [`map_pipe_err`] to capture state at the point of failure. - pub fn snapshot_for_error(&self) -> Self { + pub fn snapshot_for_error(&mut self) -> Self { PipeSubrequestState { saved_body: self.saved_body.clone(), header_received: self.header_received, join_handle: None, + pipe_rx: self.pipe_rx.take(), } } } @@ -91,7 +100,7 @@ impl PipeSubrequestError { fn map_pipe_err>>( result: Result, from_subreq: bool, - state: &PipeSubrequestState, + state: &mut PipeSubrequestState, ) -> Result { result.map_err(|e| PipeSubrequestError::new(e, from_subreq, state.snapshot_for_error())) } @@ -212,7 +221,10 @@ where }); state.join_handle = Some(join_handle); let tx = subrequest_handle.tx; - let mut rx = subrequest_handle.rx; + // Move rx into state immediately so it survives all exit paths (early `?` + // returns, errors, and the normal success path). The select loop borrows it + // back via `state.pipe_rx.as_mut().expect(...)`. + state.pipe_rx = Some(subrequest_handle.rx); let mut wants_body = false; let mut wants_body_rx_err = false; @@ -229,7 +241,7 @@ where .or_err(InternalError, "try_reserve() body pipe for subrequest"); tokio::select! { - task = rx.recv(), if !response_state.upstream_done() => { + task = state.pipe_rx.as_mut().expect("pipe_rx always set after spawn").recv(), if !response_state.upstream_done() => { debug!("upstream event: {:?}", task); if let Some(t) = task { // Did the subrequest get headers? @@ -239,17 +251,17 @@ where // pull as many tasks as we can const TASK_BUFFER_SIZE: usize = 4; let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - let task = map_pipe_err(task_filter(t), false, &state)?; + let task = map_pipe_err(task_filter(t), false, &mut state)?; if let Some(filtered) = task { tasks.push(filtered); } // tokio::task::unconstrained because now_or_never may yield None when the future is ready - while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { + while let Some(maybe_task) = tokio::task::unconstrained(state.pipe_rx.as_mut().expect("pipe_rx always set after spawn").recv()).now_or_never() { if let Some(t) = maybe_task { if matches!(&t, HttpTask::Header(..)) { state.header_received = true; } - let task = map_pipe_err(task_filter(t), false, &state)?; + let task = map_pipe_err(task_filter(t), false, &mut state)?; if let Some(filtered) = task { tasks.push(filtered); } @@ -259,7 +271,7 @@ where } // FIXME: if one of these tasks is Failed(e), the session will return that // error; in this case, the error is actually from the subreq - let response_done = map_pipe_err(session.write_response_tasks(tasks).await, false, &state)?; + let response_done = map_pipe_err(session.write_response_tasks(tasks).await, false, &mut state)?; // NOTE: technically it is the downstream whose response state has finished here // we consider the subrequest's work done however @@ -309,7 +321,7 @@ where // this is the first subrequest // send the body debug!("downstream event: main body for subrequest"); - let body = map_pipe_err(body.map_err(|e| e.into_down()), false, &state)?; + let body = map_pipe_err(body.map_err(|e| e.into_down()), false, &mut state)?; // If the request is websocket, `None` body means the request is closed. // Set the response to be done as well so that the request completes normally. @@ -325,7 +337,7 @@ where state.saved_body.as_mut(), send_permit.expect("checked is_ok()"), ) - .await, false, &state)?; + .await, false, &mut state)?; downstream_state.maybe_finished(request_done); @@ -346,7 +358,7 @@ where is_body_done, None, send_permit.expect("checked is_ok()"), - ), false, &state)?; + ), false, &mut state)?; downstream_state.maybe_finished(request_done); }, From 1f83d3c8cecb025cede75b3f045f79b73cdc3309 Mon Sep 17 00:00:00 2001 From: Ian Crutcher Date: Mon, 13 Apr 2026 14:32:56 -0500 Subject: [PATCH 35/52] Changing type of PeerOptions curve to Cow to allow for dynamically determined curves --- .bleep | 2 +- .gitignore | 3 ++- .../src/connectors/tls/boringssl_openssl/mod.rs | 2 +- pingora-core/src/upstreams/peer.rs | 15 ++++++++------- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.bleep b/.bleep index 8639e006..fb32c7f8 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -3a69c94066ae8bab489c56856c576bacfd341d6b \ No newline at end of file +5986e41a0552d4d071a8c546bf17e46ec8b7d59d \ No newline at end of file diff --git a/.gitignore b/.gitignore index abc8cf51..a8fa88f9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ dhat-heap.json .vscode .idea .cover -bleeper.user.toml \ No newline at end of file +bleeper.user.toml +.DS_Store diff --git a/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs b/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs index 9bb3a5a6..8585deef 100644 --- a/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs +++ b/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs @@ -193,7 +193,7 @@ where } } - if let Some(curve) = peer.get_peer_options().and_then(|o| o.curves) { + if let Some(curve) = peer.get_peer_options().and_then(|o| o.curves.as_deref()) { ssl_set_groups_list(&mut ssl_conf, curve).or_err(InternalError, "invalid curves")?; } diff --git a/pingora-core/src/upstreams/peer.rs b/pingora-core/src/upstreams/peer.rs index 78c6dbcc..b5ec9d76 100644 --- a/pingora-core/src/upstreams/peer.rs +++ b/pingora-core/src/upstreams/peer.rs @@ -33,6 +33,7 @@ use pingora_error::{ }; #[cfg(feature = "s2n")] use pingora_s2n::S2NPolicy; +use std::borrow::Cow; use std::collections::BTreeMap; use std::fmt::{Display, Formatter, Result as FmtResult}; use std::hash::{Hash, Hasher}; @@ -447,16 +448,16 @@ pub struct PeerOptions { /// It exists primarily for compatibility with legacy servers that send malformed headers. pub allow_h1_response_invalid_content_length: bool, pub extra_proxy_headers: BTreeMap>, - // The list of curve the tls connection should advertise - // if `None`, the default curves will be used - pub curves: Option<&'static str>, - // see ssl_use_second_key_share + /// The list of curves the tls connection should advertise + /// if `None`, the default curves will be used + pub curves: Option>, + /// see ssl_use_second_key_share pub second_keyshare: bool, - // whether to enable TCP fast open + /// whether to enable TCP fast open pub tcp_fast_open: bool, - // use Arc because Clone is required but not allowed in trait object + /// use Arc because Clone is required but not allowed in trait object pub tracer: Option, - // A custom L4 connector to use to establish new L4 connections + /// A custom L4 connector to use to establish new L4 connections pub custom_l4: Option>, #[derivative(Debug = "ignore")] pub upstream_tcp_sock_tweak_hook: From 6c523ee7538f2c5b127cce6a797ee92c38e2bb89 Mon Sep 17 00:00:00 2001 From: Andrew Hauck Date: Tue, 21 Apr 2026 10:18:38 -0700 Subject: [PATCH 36/52] Add support for fractional delta seconds that are floored (optional RFC 9111 compliance) --- .bleep | 2 +- pingora-cache/src/cache_control.rs | 240 ++++++++++++++++++++++++++++- 2 files changed, 239 insertions(+), 3 deletions(-) diff --git a/.bleep b/.bleep index fb32c7f8..0d1fd142 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -5986e41a0552d4d071a8c546bf17e46ec8b7d59d \ No newline at end of file +26f2c0e18228072ec5d1e699d78e79cba56ae9b5 \ No newline at end of file diff --git a/pingora-cache/src/cache_control.rs b/pingora-cache/src/cache_control.rs index 98af7fbb..f8612080 100644 --- a/pingora-cache/src/cache_control.rs +++ b/pingora-cache/src/cache_control.rs @@ -92,6 +92,51 @@ impl DirectiveValue { } } } + + /// Parse the [DirectiveValue] as delta seconds, permitting a fractional component. + /// + /// Values with a fractional part are rounded down (floored) to the nearest + /// non-negative integer: e.g. `1.9` -> `1`, `0.5` -> `0`. This is useful + /// for compatibility with upstreams that emit fractional ttls (which + /// RFC 9111 strict integer parsing rejects). + /// + /// Integer parsing is attempted first, so strictly-numeric values (including + /// overflow-capped values and quoted integers) behave identically to + /// [Self::parse_as_delta_seconds]. Negative, non-finite, or non-numeric + /// values still return an error. + /// + /// `"`s are ignored. The value is capped to [DELTA_SECONDS_OVERFLOW_VALUE]. + pub fn parse_as_delta_seconds_floor(&self) -> Result { + // UTF-8 validate once; on non-UTF8 input, propagate the same error as + // [Self::parse_as_delta_seconds]. + let s = self.parse_as_str()?; + match s.parse::() { + Ok(value) => Ok(value), + Err(e) if e.kind() == &IntErrorKind::PosOverflow => Ok(DELTA_SECONDS_OVERFLOW_VALUE), + Err(int_err) => { + // Fall back to parsing as a non-negative finite float and floor. + // On any failure, return an error equivalent to the strict + // [Self::parse_as_delta_seconds] u32-parse error. + match s.parse::() { + Ok(f) if f.is_finite() && f >= 0.0 => { + if f >= DELTA_SECONDS_OVERFLOW_VALUE as f64 { + Ok(DELTA_SECONDS_OVERFLOW_VALUE) + } else { + // Safe cast: `f` is finite, non-negative, and strictly + // less than `DELTA_SECONDS_OVERFLOW_VALUE` (i32::MAX), + // which fits in `u32` after flooring. + Ok(f.floor() as u32) + } + } + _ => Error::e_because( + ErrorType::InternalError, + "could not parse value as u32", + int_err, + ), + } + } + } + } } /// An ordered map to store cache control key value pairs. @@ -102,6 +147,15 @@ pub type DirectiveMap = IndexMap>; pub struct CacheControl { /// The parsed directives pub directives: DirectiveMap, + /// When set, delta-seconds directives (`max-age`, `s-maxage`, + /// `stale-while-revalidate`, `stale-if-error`) accept fractional values + /// and round them down to the nearest non-negative integer. + /// + /// Defaults to `false`, matching RFC 9111 strict integer parsing. Enable + /// via [CacheControl::with_float_seconds] (or by assigning the field + /// directly) for contexts that need to interoperate with upstreams that + /// emit fractional ttls. + pub allow_float_seconds: bool, } /// Cacheability calculated from cache control. @@ -198,7 +252,20 @@ impl CacheControl { directives.insert(key.unwrap(), value); } } - Some(CacheControl { directives }) + Some(CacheControl { + directives, + allow_float_seconds: false, + }) + } + + /// Builder setter: enable fractional delta-seconds parsing. + /// + /// See [CacheControl::allow_float_seconds] for semantics. Returns `self` + /// so it can be chained onto a parser call, e.g. + /// `CacheControl::from_resp_headers(&resp).map(|cc| cc.with_float_seconds())`. + pub fn with_float_seconds(mut self) -> Self { + self.allow_float_seconds = true; + self } /// Parse from the given header name in `headers` @@ -282,7 +349,12 @@ impl CacheControl { fn parse_delta_seconds(&self, key: &str) -> Result> { if let Some(Some(dir_value)) = self.directives.get(key) { - Ok(Some(dir_value.parse_as_delta_seconds()?)) + let value = if self.allow_float_seconds { + dir_value.parse_as_delta_seconds_floor()? + } else { + dir_value.parse_as_delta_seconds()? + }; + Ok(Some(value)) } else { Ok(None) } @@ -869,4 +941,168 @@ mod tests { let cc = CacheControl::from_req_headers(&req).unwrap(); assert!(cc.only_if_cached()) } + + #[test] + fn test_parse_as_delta_seconds_floor() { + // Integer values behave identically to the strict parser + let v = DirectiveValue(b"10".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 10); + + let v = DirectiveValue(b"\"10\"".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 10); + + // Quoted fractional values are unwrapped by parse_as_str and floored. + let v = DirectiveValue(b"\"1.5\"".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 1); + + let v = DirectiveValue(b"0".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 0); + + // Integer positive overflow is still capped + let v = DirectiveValue(b"99999999999999999999".to_vec()); + assert_eq!( + v.parse_as_delta_seconds_floor().unwrap(), + DELTA_SECONDS_OVERFLOW_VALUE + ); + + // Floats are floored + let v = DirectiveValue(b"1.5".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 1); + + let v = DirectiveValue(b"1.9".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 1); + + let v = DirectiveValue(b"0.5".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 0); + + let v = DirectiveValue(b"3600.0".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 3600); + + // Float positive overflow is capped + let v = DirectiveValue(b"99999999999.5".to_vec()); + assert_eq!( + v.parse_as_delta_seconds_floor().unwrap(), + DELTA_SECONDS_OVERFLOW_VALUE + ); + + // Negative values are rejected (matches strict behavior) + assert!(DirectiveValue(b"-1".to_vec()) + .parse_as_delta_seconds_floor() + .is_err()); + assert!(DirectiveValue(b"-1.5".to_vec()) + .parse_as_delta_seconds_floor() + .is_err()); + + // Non-finite / non-numeric values are rejected + assert!(DirectiveValue(b"NaN".to_vec()) + .parse_as_delta_seconds_floor() + .is_err()); + assert!(DirectiveValue(b"inf".to_vec()) + .parse_as_delta_seconds_floor() + .is_err()); + assert!(DirectiveValue(b"abc".to_vec()) + .parse_as_delta_seconds_floor() + .is_err()); + + // Non-UTF8 bytes are rejected with the same utf-8 error as the strict parser. + let v = DirectiveValue(b"ba\xFFr".to_vec()); + assert_eq!( + v.parse_as_delta_seconds_floor() + .unwrap_err() + .context + .unwrap() + .to_string(), + "could not parse value as utf8", + ); + } + + #[test] + fn test_cache_control_allow_float_seconds_non_utf8_value() { + // Non-UTF8 bytes inside `max-age` should still produce the utf-8 error + // when the float-permitting flag is on, matching the strict parser. + let mut resp = response::Builder::new().body(()).unwrap(); + resp.headers_mut().insert( + CACHE_CONTROL, + HeaderValue::from_bytes(b"max-age=ba\xFFr").unwrap(), + ); + let (parts, _) = resp.into_parts(); + let cc = CacheControl::from_resp_headers(&parts) + .unwrap() + .with_float_seconds(); + assert_eq!( + cc.max_age().unwrap_err().context.unwrap().to_string(), + "could not parse value as utf8", + ); + } + + #[test] + fn test_cache_control_allow_float_seconds_default_off() { + // Default (strict) parsing: fractional values produce an error, and + // [InterpretCacheControl::fresh_duration] returns None, matching the + // pre-existing behavior. + let resp = build_response(CACHE_CONTROL, "max-age=10.7"); + let cc = CacheControl::from_resp_headers(&resp).unwrap(); + assert!(!cc.allow_float_seconds); + assert!(cc.max_age().is_err()); + assert!(cc.fresh_duration().is_none()); + } + + #[test] + fn test_cache_control_with_float_seconds() { + // `max-age` with a fractional value is floored when the flag is on. + let resp = build_response(CACHE_CONTROL, "max-age=10.7"); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert!(cc.allow_float_seconds); + assert_eq!(cc.max_age().unwrap().unwrap(), 10); + assert_eq!(cc.fresh_duration().unwrap(), Duration::from_secs(10)); + + // `s-maxage` still wins over `max-age` and is also floored. + let resp = build_response(CACHE_CONTROL, "s-maxage=3600.99, max-age=1800"); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert_eq!(cc.s_maxage().unwrap().unwrap(), 3600); + assert_eq!(cc.fresh_duration().unwrap(), Duration::from_secs(3600)); + + // `stale-while-revalidate` and `stale-if-error` also pick up flooring. + let resp = build_response( + CACHE_CONTROL, + "max-age=10, stale-while-revalidate=60.5, stale-if-error=30.9", + ); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert_eq!(cc.stale_while_revalidate().unwrap().unwrap(), 60); + assert_eq!(cc.stale_if_error().unwrap().unwrap(), 30); + assert_eq!( + cc.serve_stale_while_revalidate_duration().unwrap(), + Duration::from_secs(60) + ); + assert_eq!( + cc.serve_stale_if_error_duration().unwrap(), + Duration::from_secs(30) + ); + + // Integer values are unaffected when the flag is on. + let resp = build_response(CACHE_CONTROL, "max-age=12345"); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert_eq!(cc.fresh_duration().unwrap(), Duration::from_secs(12345)); + + // Invalid (non-numeric, negative) values still fail to parse under the flag. + let resp = build_response(CACHE_CONTROL, "max-age=abc"); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert!(cc.max_age().is_err()); + + let resp = build_response(CACHE_CONTROL, "max-age=-1.5"); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert!(cc.max_age().is_err()); + } } From a95f8c483fd769948053d9a31e895865d1239b06 Mon Sep 17 00:00:00 2001 From: Shane Utt Date: Sun, 12 Apr 2026 17:56:01 +0000 Subject: [PATCH 37/52] feat: make rustls cert public Includes-commit: 875e4d944fa71d8000d251e7d5c689c3de5f3546 Replicated-from: https://github.com/cloudflare/pingora/pull/858 Signed-off-by: Shane Utt --- .bleep | 2 +- pingora-core/src/utils/tls/rustls.rs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.bleep b/.bleep index 0d1fd142..d3ce0bd6 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -26f2c0e18228072ec5d1e699d78e79cba56ae9b5 \ No newline at end of file +99c4b224507885ae7bf9bf00a6c4b2c2cef119a7 \ No newline at end of file diff --git a/pingora-core/src/utils/tls/rustls.rs b/pingora-core/src/utils/tls/rustls.rs index 429b3724..d4e4e5c9 100644 --- a/pingora-core/src/utils/tls/rustls.rs +++ b/pingora-core/src/utils/tls/rustls.rs @@ -101,17 +101,17 @@ pub struct CertKey { certificates: Vec, } -#[self_referencing] +#[self_referencing(pub_extras)] #[derive(Debug)] pub struct WrappedX509 { - raw_cert: Vec, + pub raw_cert: Vec, #[borrows(raw_cert)] #[covariant] - cert: X509Certificate<'this>, + pub cert: X509Certificate<'this>, } -fn parse_x509(raw_cert: &C) -> X509Certificate<'_> +pub fn parse_x509(raw_cert: &C) -> X509Certificate<'_> where C: AsRef<[u8]>, { From bc9870d49775bdbd0d3806139e9814ad59c5a782 Mon Sep 17 00:00:00 2001 From: Kevin Guthrie Date: Tue, 28 Apr 2026 12:00:45 -0400 Subject: [PATCH 38/52] Fix flaky test_connector_bind_to on macOS/CI The test connects to 240.0.0.1 (reserved) while bound to localhost and asserts the error is ConnectError or ConnectTimedout. On macOS and some CI runners the kernel returns ENETUNREACH immediately, which maps to ConnectNoRoute. Accept that as a valid outcome. This is the same class of fix applied to test_conn_timeout and test_tls_connect_timeout_supersedes_total in 542129f. --- .bleep | 2 +- pingora-core/src/connectors/mod.rs | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.bleep b/.bleep index d3ce0bd6..467ea9ee 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -99c4b224507885ae7bf9bf00a6c4b2c2cef119a7 \ No newline at end of file +28f129ce840ca1dc8854cbcbdb850f533fdb1a94 \ No newline at end of file diff --git a/pingora-core/src/connectors/mod.rs b/pingora-core/src/connectors/mod.rs index 0e3c727c..35067fa3 100644 --- a/pingora-core/src/connectors/mod.rs +++ b/pingora-core/src/connectors/mod.rs @@ -613,8 +613,16 @@ mod tests { let stream = connector.new_stream(&peer).await; let error = stream.unwrap_err(); - // XXX: some systems will allow the socket to bind and connect without error, only to timeout - assert!(error.etype() == &ConnectError || error.etype() == &ConnectTimedout) + // The exact error varies by platform: Linux may return ConnectError, + // some systems time out (ConnectTimedout), and macOS/others may + // return ConnectNoRoute (ENETUNREACH) for unreachable addresses. + assert!( + error.etype() == &ConnectError + || error.etype() == &ConnectTimedout + || error.etype() == &ConnectNoRoute, + "unexpected error type: {:?}", + error.etype() + ) } /// Helper function for testing error handling in the `do_connect` function. From aece99322ac94f738c494c866ac3467aab663c34 Mon Sep 17 00:00:00 2001 From: Matthew Gumport Date: Fri, 1 May 2026 00:10:21 +0000 Subject: [PATCH 39/52] let h2 accept loop drain in-flight streams on shutdown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old loop used `tokio::select!` with a `poll_closed` path that bailed as soon as the shutdown signal fired. RFC 9113 §6.8 says we have to process streams below the final last_stream_id. We weren't doing that. Now we call `graceful_shutdown` on the connection, but streams that were already in the buffer or have a lower stream number get surfaced and dispatched normally. The loop exits once the codec flushes the closing GOAWAY. This also pulls the accept loop out of `apps/mod.rs` so that it's more easily testable and usable from a test environment. --- .bleep | 2 +- pingora-core/src/apps/mod.rs | 39 +-- pingora-core/src/protocols/http/v2/mod.rs | 349 ++++++++++++++++++- pingora-core/src/protocols/http/v2/server.rs | 72 ++++ 4 files changed, 432 insertions(+), 30 deletions(-) diff --git a/.bleep b/.bleep index 467ea9ee..5d1103fb 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -28f129ce840ca1dc8854cbcbdb850f533fdb1a94 \ No newline at end of file +9ac70aa780615adf143c1543e0e566e3b8f6d40f \ No newline at end of file diff --git a/pingora-core/src/apps/mod.rs b/pingora-core/src/apps/mod.rs index 82989e5c..4de4f9a9 100644 --- a/pingora-core/src/apps/mod.rs +++ b/pingora-core/src/apps/mod.rs @@ -20,7 +20,6 @@ use crate::server::ShutdownWatch; use async_trait::async_trait; use log::{debug, error}; use std::any::Any; -use std::future::poll_fn; use std::sync::Arc; use crate::protocols::http::v2::server; @@ -250,8 +249,7 @@ where }); let h2_options = self.h2_options(); - let h2_conn = server::handshake(stream, h2_options).await; - let mut h2_conn = match h2_conn { + let h2_conn = match server::handshake(stream, h2_options).await { Err(e) => { error!("H2 handshake error {e}"); return None; @@ -259,36 +257,21 @@ where Ok(c) => c, }; - let mut shutdown = shutdown.clone(); - loop { - // this loop ends when the client decides to close the h2 conn - // TODO: add a timeout? - let h2_stream = tokio::select! { - _ = shutdown.changed() => { - h2_conn.graceful_shutdown(); - let _ = poll_fn(|cx| h2_conn.poll_closed(cx)) - .await.map_err(|e| error!("H2 error waiting for shutdown {e}")); - return None; - } - h2_stream = server::HttpSession::from_h2_conn(&mut h2_conn, digest.clone()) => h2_stream - }; - let h2_stream = match h2_stream { - Err(e) => { - // It is common for the client to just disconnect TCP without properly - // closing H2. So we don't log the errors here - debug!("H2 error when accepting new stream {e}"); - return None; - } - Ok(s) => s?, // None means the connection is ready to be closed - }; - let app = self.clone(); - let shutdown = shutdown.clone(); + // The accept-loop body — including the graceful-shutdown state + // machine — lives in `server::accept_downstream_sessions` so that + // the same code path is exercised by tests in `protocols::http::v2`. + let app = self.clone(); + let shutdown_for_session = shutdown.clone(); + server::accept_downstream_sessions(h2_conn, digest, shutdown.clone(), |h2_stream| { + let app = app.clone(); + let shutdown = shutdown_for_session.clone(); pingora_runtime::current_handle().spawn(async move { // Note, `PersistentSettings` not currently relevant for h2 app.process_new_http(ServerSession::new_http2(h2_stream), &shutdown) .await; }); - } + }) + .await; } else if custom || matches!(stream.selected_alpn_proto(), Some(ALPN::Custom(_))) { return self.clone().process_custom_session(stream, shutdown).await; } else { diff --git a/pingora-core/src/protocols/http/v2/mod.rs b/pingora-core/src/protocols/http/v2/mod.rs index 615fcee5..8f664c9d 100644 --- a/pingora-core/src/protocols/http/v2/mod.rs +++ b/pingora-core/src/protocols/http/v2/mod.rs @@ -93,13 +93,14 @@ mod test { use h2::frame::*; use http::{HeaderMap, Method, Uri}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream}; + use tokio::sync::{oneshot, watch}; use tokio_stream::StreamExt; use pingora_http::{RequestHeader, ResponseHeader}; use pingora_timeout::sleep; use crate::protocols::{ - http::v2::server::{handshake, HttpSession}, + http::v2::server::{self, handshake, HttpSession}, Digest, }; @@ -274,4 +275,350 @@ mod test { assert!(handle.await.is_ok()); } } + + #[tokio::test] + async fn test_graceful_shutdown_processes_inflight_stream() { + // HEADERS arrive on the server after the shutdown signal + // fires, but before the client has observed GOAWAY. + let (mut client, server) = duplex(65536); + // Use channels for deterministic timing. + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let (write_headers_tx, write_headers_rx) = oneshot::channel::<()>(); + + let client_handle = tokio::spawn(async move { + client + .write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + .await + .unwrap(); + let mut codec: h2::Codec = h2::Codec::new(client); + codec.send(Settings::default().into()).await.unwrap(); + codec.send(Settings::ack().into()).await.unwrap(); + + // Wait until the test has triggered shutdown on the server before + // sending HEADERS. See the function-level comment for why. + write_headers_rx.await.unwrap(); + + let mut headers = Headers::new( + 1.into(), + Pseudo::request( + Method::GET, + Uri::from_static("https://one.one.one.one/"), + None, + ), + HeaderMap::new(), + ); + headers.set_end_headers(); + headers.set_end_stream(); + codec.send(headers.into()).await.unwrap(); + + let mut saw_response = false; + let mut saw_goaway = false; + let _ = pingora_timeout::timeout(Duration::from_secs(5), async { + while let Some(frame) = codec.next().await { + match frame.unwrap() { + h2::frame::Frame::Headers(_) => { + saw_response = true; + } + h2::frame::Frame::GoAway(_) => { + saw_goaway = true; + } + _ => {} + } + if saw_response && saw_goaway { + break; + } + } + }) + .await; + + assert!(saw_response, "expected response for stream 1"); + assert!(saw_goaway, "expected at least one GOAWAY frame"); + }); + + let connection = handshake(Box::new(server), None).await.unwrap(); + let digest = Arc::new(Digest::default()); + + let trigger = tokio::spawn(async move { + sleep(Duration::from_millis(50)).await; + shutdown_tx.send(true).unwrap(); + // Wait long enough that the server task is guaranteed to have + // observed the shutdown signal and committed to its post-shutdown + // path before putting anything on the wire. + sleep(Duration::from_millis(50)).await; + write_headers_tx.send(()).unwrap(); + }); + + let mut session_handles = vec![]; + server::accept_downstream_sessions(connection, digest, shutdown_rx, |mut session| { + session_handles.push(tokio::spawn(async move { + let req = session.req_header(); + assert_eq!(req.method, Method::GET); + let resp = Box::new(ResponseHeader::build(200, None).unwrap()); + session.write_response_header(resp, true).unwrap(); + })); + }) + .await; + + trigger.await.unwrap(); + assert_eq!( + session_handles.len(), + 1, + "expected stream 1 to be surfaced after shutdown_initiated" + ); + for h in session_handles { + h.await.unwrap(); + } + client_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_graceful_shutdown_processes_post_goaway_stream() { + // Client opens stream 1 after it has observed the + // server's GOAWAY frame. Stream 1 is below the GOAWAY(MAX) + // last_stream_id, so per RFC 9113 §6.8 the server must still process + // it. + let (mut client, server) = duplex(65536); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + let client_handle = tokio::spawn(async move { + client + .write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + .await + .unwrap(); + let mut codec: h2::Codec = h2::Codec::new(client); + codec.send(Settings::default().into()).await.unwrap(); + codec.send(Settings::ack().into()).await.unwrap(); + + // Block until the server's GOAWAY is observed. HEADERS for + // stream 1 are sent strictly after this point. + let mut saw_goaway_before_headers = false; + while let Some(frame) = codec.next().await { + if matches!(frame.unwrap(), h2::frame::Frame::GoAway(_)) { + saw_goaway_before_headers = true; + break; + } + } + assert!( + saw_goaway_before_headers, + "expected GOAWAY before sending HEADERS for stream 1" + ); + + let mut headers = Headers::new( + 1.into(), + Pseudo::request( + Method::GET, + Uri::from_static("https://one.one.one.one/"), + None, + ), + HeaderMap::new(), + ); + headers.set_end_headers(); + headers.set_end_stream(); + codec.send(headers.into()).await.unwrap(); + + let mut saw_response_for_stream_1 = false; + let _ = pingora_timeout::timeout(Duration::from_secs(5), async { + while let Some(frame) = codec.next().await { + if let Ok(h2::frame::Frame::Headers(h)) = frame { + if h.stream_id() == 1u32 { + saw_response_for_stream_1 = true; + break; + } + } + } + }) + .await; + assert!( + saw_response_for_stream_1, + "expected response on stream 1 after GOAWAY", + ); + }); + + let connection = handshake(Box::new(server), None).await.unwrap(); + let digest = Arc::new(Digest::default()); + + let trigger = tokio::spawn(async move { + sleep(Duration::from_millis(50)).await; + shutdown_tx.send(true).unwrap(); + }); + + let mut session_handles = vec![]; + server::accept_downstream_sessions(connection, digest, shutdown_rx, |mut session| { + session_handles.push(tokio::spawn(async move { + let resp = Box::new(ResponseHeader::build(200, None).unwrap()); + session.write_response_header(resp, true).unwrap(); + })); + }) + .await; + + trigger.await.unwrap(); + assert_eq!( + session_handles.len(), + 1, + "expected exactly one stream surfaced after GOAWAY" + ); + for h in session_handles { + h.await.unwrap(); + } + client_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_graceful_shutdown_idle_connection_exits_promptly() { + let (mut client, server) = duplex(65536); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + let client_handle = tokio::spawn(async move { + client + .write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + .await + .unwrap(); + let mut codec: h2::Codec = h2::Codec::new(client); + codec.send(Settings::default().into()).await.unwrap(); + codec.send(Settings::ack().into()).await.unwrap(); + + // Wait for the server's GOAWAY, then drop the codec to close the + // connection so the accept loop can exit. + let mut saw_goaway = false; + let _ = pingora_timeout::timeout(Duration::from_secs(3), async { + while let Some(frame) = codec.next().await { + if matches!(frame.unwrap(), h2::frame::Frame::GoAway(_)) { + saw_goaway = true; + break; + } + } + }) + .await; + assert!(saw_goaway, "expected GOAWAY"); + }); + + let connection = handshake(Box::new(server), None).await.unwrap(); + let digest = Arc::new(Digest::default()); + + let trigger = tokio::spawn(async move { + sleep(Duration::from_millis(20)).await; + shutdown_tx.send(true).unwrap(); + }); + + let result = pingora_timeout::timeout( + Duration::from_secs(2), + server::accept_downstream_sessions(connection, digest, shutdown_rx, |_session| { + panic!("did not expect any sessions on an idle connection"); + }), + ) + .await; + assert!(result.is_ok(), "accept loop hung after shutdown"); + + trigger.await.unwrap(); + client_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_graceful_shutdown_refuses_stream_above_last_stream_id() { + // After the server commits to a final last_stream_id and emits the + // closing GOAWAY, any stream the client tries to open above that id + // must be refused. The accept loop must not surface it and must exit + // cleanly via the `Ok(None)` arm of `from_h2_conn`. + let (mut client, server) = duplex(65536); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + let client_handle = tokio::spawn(async move { + client + .write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + .await + .unwrap(); + let mut codec: h2::Codec = h2::Codec::new(client); + codec.send(Settings::default().into()).await.unwrap(); + codec.send(Settings::ack().into()).await.unwrap(); + + // Open stream 1 so the server will commit last_stream_id >= 1 + // when it eventually emits its closing GOAWAY. + let mut headers = Headers::new( + 1.into(), + Pseudo::request( + Method::GET, + Uri::from_static("https://one.one.one.one"), + None, + ), + HeaderMap::new(), + ); + headers.set_end_headers(); + headers.set_end_stream(); + codec.send(headers.into()).await.unwrap(); + + // Drain frames from the server. Break once we've seen the + // response and at least one GOAWAY so the test doesn't race + // its own outer timeout while waiting on a quiet codec. + let mut saw_response = false; + let mut saw_goaway = false; + let _ = pingora_timeout::timeout(Duration::from_secs(3), async { + while let Some(frame) = codec.next().await { + match frame { + Ok(h2::frame::Frame::Headers(h)) if h.stream_id() == 1 => { + saw_response = true; + } + Ok(h2::frame::Frame::GoAway(_)) => { + saw_goaway = true; + } + Ok(_) => {} + Err(_) => break, + } + if saw_response && saw_goaway { + break; + } + } + }) + .await; + assert!(saw_response, "expected response for stream 1"); + assert!(saw_goaway, "expected at least one GOAWAY frame"); + + // Try to open stream 3 (above last_stream_id). The send may + // succeed locally (duplex buffer) or fail (peer half closed); + // either way the server-side codec must not surface the stream. + let mut headers = Headers::new( + 3.into(), + Pseudo::request( + Method::GET, + Uri::from_static("https://one.one.one.one"), + None, + ), + HeaderMap::new(), + ); + headers.set_end_headers(); + headers.set_end_stream(); + let _ = codec.send(headers.into()).await; + }); + + let connection = handshake(Box::new(server), None).await.unwrap(); + let digest = Arc::new(Digest::default()); + + let trigger = tokio::spawn(async move { + sleep(Duration::from_millis(50)).await; + shutdown_tx.send(true).unwrap(); + }); + + let mut session_handles = vec![]; + let result = pingora_timeout::timeout( + Duration::from_secs(5), + server::accept_downstream_sessions(connection, digest, shutdown_rx, |mut session| { + session_handles.push(tokio::spawn(async move { + let resp = Box::new(ResponseHeader::build(200, None).unwrap()); + session.write_response_header(resp, true).unwrap(); + })); + }), + ) + .await; + assert!(result.is_ok(), "accept loop hung after shutdown"); + assert_eq!( + session_handles.len(), + 1, + "only stream 1 may be surfaced; streams above last_stream_id must be refused" + ); + + trigger.await.unwrap(); + for h in session_handles { + h.await.unwrap(); + } + client_handle.await.unwrap(); + } } diff --git a/pingora-core/src/protocols/http/v2/server.rs b/pingora-core/src/protocols/http/v2/server.rs index 363b7357..604d53c6 100644 --- a/pingora-core/src/protocols/http/v2/server.rs +++ b/pingora-core/src/protocols/http/v2/server.rs @@ -34,6 +34,7 @@ use crate::protocols::http::date::get_cached_date; use crate::protocols::http::v1::client::http_req_header_to_wire; use crate::protocols::http::HttpTask; use crate::protocols::{Digest, SocketAddr, Stream}; +use crate::server::ShutdownWatch; use crate::{Error, ErrorType, OrErr, Result}; const BODY_BUF_LIMIT: usize = 1024 * 64; @@ -63,6 +64,77 @@ pub async fn handshake(io: Stream, options: Option) -> Result( + mut conn: H2Connection, + digest: Arc, + mut shutdown: ShutdownWatch, + mut on_session: F, +) where + F: FnMut(HttpSession), +{ + let mut shutdown_initiated = false; + loop { + let h2_stream = if shutdown_initiated { + HttpSession::from_h2_conn(&mut conn, digest.clone()).await + } else { + tokio::select! { + // Poll the shutdown signal first so a concurrent signal is + // observed deterministically. `from_h2_conn` is cancel-safe + // and is polled again on the next iteration. + biased; + _ = shutdown.changed() => { + conn.graceful_shutdown(); + shutdown_initiated = true; + continue; + } + h2_stream = HttpSession::from_h2_conn(&mut conn, digest.clone()) => h2_stream, + } + }; + match h2_stream { + Err(e) => { + // It is common for the client to just disconnect TCP without + // properly closing H2. So we don't log the errors here + debug!("H2 error when accepting new stream {e}"); + return; + } + // None means the connection is ready to be closed + Ok(None) => return, + Ok(Some(session)) => on_session(session), + } + } +} + use futures::task::Context; use futures::task::Poll; use std::pin::Pin; From 06cbc1ca81018a95b3a0ba47c8bd3b5db15c4b88 Mon Sep 17 00:00:00 2001 From: Abhishek Aiyer Date: Fri, 1 May 2026 10:19:18 +0100 Subject: [PATCH 40/52] Derive Clone and Debug on HttpServerOptions --- .bleep | 2 +- pingora-core/src/apps/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bleep b/.bleep index 5d1103fb..53666e60 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -9ac70aa780615adf143c1543e0e566e3b8f6d40f \ No newline at end of file +182be87906bd49ad0910548b49efe9685a8a1792 \ No newline at end of file diff --git a/pingora-core/src/apps/mod.rs b/pingora-core/src/apps/mod.rs index 4de4f9a9..93bea8b4 100644 --- a/pingora-core/src/apps/mod.rs +++ b/pingora-core/src/apps/mod.rs @@ -57,7 +57,7 @@ pub trait ServerApp { async fn cleanup(&self) {} } #[non_exhaustive] -#[derive(Default)] +#[derive(Clone, Debug, Default)] /// HTTP Server options that control how the server handles some transport types. pub struct HttpServerOptions { /// Allow HTTP/2 for plaintext. From 043f1f604bec2dd0ad4a0a2e032c2c5109949b36 Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Thu, 30 Apr 2026 09:52:06 -0700 Subject: [PATCH 41/52] Use power-of-two selection to balance eviction This is a trivially simple way to drive toward uniform weights between LRU shards if they are unbalanced. --- .bleep | 2 +- pingora-cache/src/eviction/lru.rs | 4 +- pingora-lru/src/lib.rs | 271 ++++++++++++++++++++++++++---- 3 files changed, 242 insertions(+), 35 deletions(-) diff --git a/.bleep b/.bleep index 53666e60..8db81aa1 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -182be87906bd49ad0910548b49efe9685a8a1792 \ No newline at end of file +5c050d84d66e44913eb6bc46f2b71aa1916ce77b \ No newline at end of file diff --git a/pingora-cache/src/eviction/lru.rs b/pingora-cache/src/eviction/lru.rs index 96285700..f7a03f99 100644 --- a/pingora-cache/src/eviction/lru.rs +++ b/pingora-cache/src/eviction/lru.rs @@ -72,7 +72,9 @@ impl Manager { self.0.shard_weight(shard) } - /// Get the number of items in a specific shard + /// Get the number of items in a specific shard. Best-effort + /// lock-free read; see [`pingora_lru::Lru::shard_len`] for the + /// consistency semantics. pub fn shard_len(&self, shard: usize) -> usize { self.0.shard_len(shard) } diff --git a/pingora-lru/src/lib.rs b/pingora-lru/src/lib.rs index 67f59230..b2c5a642 100644 --- a/pingora-lru/src/lib.rs +++ b/pingora-lru/src/lib.rs @@ -25,11 +25,16 @@ use linked_list::{LinkedList, LinkedListIter}; use hashbrown::HashMap; use parking_lot::RwLock; +use rand::Rng; use std::sync::atomic::{AtomicUsize, Ordering}; /// The LRU with `N` shards pub struct Lru { units: [RwLock>; N], + /// Lock-free `Relaxed` shadow of each shard's item count, backing + /// [`Lru::shard_len`] and the P2C selection in [`Lru::evict_to_limit`]. + /// Maintained alongside [`Lru::len`] at every count-mutating site. + shard_lens: [AtomicUsize; N], weight: AtomicUsize, weight_limit: usize, len_watermark: Option, @@ -59,11 +64,16 @@ impl Lru { ) -> Self { // use the unsafe code from ArrayVec just to init the array let mut units = arrayvec::ArrayVec::<_, N>::new(); + let mut shard_lens = arrayvec::ArrayVec::<_, N>::new(); for _ in 0..N { units.push(RwLock::new(LruUnit::with_capacity(capacity))); + shard_lens.push(AtomicUsize::new(0)); } Lru { units: units.into_inner().map_err(|_| "").unwrap(), + shard_lens: shard_lens + .into_inner() + .expect("shard_lens ArrayVec filled with exactly N elements"), weight: AtomicUsize::new(0), weight_limit, len_watermark, @@ -73,6 +83,23 @@ impl Lru { } } + /// Increment item-count bookkeeping for `shard`. Both atomics use + /// `Relaxed`; called while holding the shard write lock so that + /// `len` and `shard_lens[shard]` advance in lockstep. + #[inline] + fn incr_count(&self, shard: usize) { + self.len.fetch_add(1, Ordering::Relaxed); + self.shard_lens[shard].fetch_add(1, Ordering::Relaxed); + } + + /// Decrement item-count bookkeeping for `shard`. See + /// [`Self::incr_count`]. + #[inline] + fn decr_count(&self, shard: usize) { + self.len.fetch_sub(1, Ordering::Relaxed); + self.shard_lens[shard].fetch_sub(1, Ordering::Relaxed); + } + /// Admit the key value to the [Lru] /// /// Return the shard index which the asset is added to @@ -91,7 +118,7 @@ impl Lru { self.weight.fetch_sub(old_weight, Ordering::Relaxed); } else { // Assume old_weight == 0 means a new item is admitted - self.len.fetch_add(1, Ordering::Relaxed); + self.incr_count(shard); } } shard @@ -150,14 +177,23 @@ impl Lru { unit.write().access(key) } - /// Evict at most one item from the given shard + /// Evict at most one item from the given shard, identified by the + /// hash-like `shard` seed (mapped into `0..N` via `% N`). /// - /// Return the evicted asset and its size if there is anything to evict + /// Return the evicted asset and its size if there is anything to evict. pub fn evict_shard(&self, shard: u64) -> Option<(T, usize)> { - let evicted = self.units[get_shard(shard, N)].write().evict(); + self.evict_shard_at(get_shard(shard, N)) + } + + /// Evict at most one item from the shard at index `shard` (in `0..N`). + /// Internal entry point that skips the `% N` round-trip in + /// [`Self::evict_shard`]. + fn evict_shard_at(&self, shard: usize) -> Option<(T, usize)> { + assert!(shard < N); + let evicted = self.units[shard].write().evict(); if let Some((_, weight)) = evicted.as_ref() { self.weight.fetch_sub(*weight, Ordering::Relaxed); - self.len.fetch_sub(1, Ordering::Relaxed); + self.decr_count(shard); self.evicted_weight.fetch_add(*weight, Ordering::Relaxed); self.evicted_len.fetch_add(1, Ordering::Relaxed); } @@ -168,43 +204,92 @@ impl Lru { /// /// Return a list of evicted items. /// - /// The evicted items are randomly selected from all the shards. + /// Each iteration selects the shard to evict from using the "power of two + /// choices" strategy: two shards are picked uniformly at random and the + /// one with more items is chosen (see + /// ). This biases + /// eviction toward longer shards and drives [`Self::shard_len`] toward a + /// uniform distribution, which keeps per-shard serialization cost (e.g. + /// `pingora_cache::eviction::lru::Manager::serialize_shard`) bounded. + /// + /// Selection is by item count, not weight, even when eviction is + /// triggered by `weight_limit`. With heavily skewed item weights this + /// may evict more items than a weight-biased policy to reach the same + /// total weight — the tradeoff is intentional in favor of bounded + /// per-shard serialization cost. + /// + /// O(1) per iteration in the common case. If the chosen shard is + /// empty when we acquire its write lock (the Relaxed shadow may + /// not always reflect actual emptiness, and P2C may tie-break to + /// an empty shard when all shadow lengths are equal), we linearly + /// probe successive shard indices until one yields an item or we + /// wrap back to the starting shard — at which point every shard + /// was observed empty and we exit. Bounded by at most N probes + /// per outer iteration. pub fn evict_to_limit(&self) -> Vec<(T, usize)> { + self.evict_to_limit_with_rng(&mut rand::thread_rng()) + } + + /// Internal entry point for [`Self::evict_to_limit`] that lets tests + /// inject a seeded RNG for deterministic P2C selection. + fn evict_to_limit_with_rng(&self, rng: &mut R) -> Vec<(T, usize)> { let mut evicted = vec![]; let mut initial_weight = self.weight(); let mut initial_len = self.len(); - let mut shard_seed = rand::random(); // start from a random shard - let mut empty_shard = 0; - - // Entries can be admitted or removed from the LRU by others during the loop below - // Track initial size not to over evict due to entries admitted after the loop starts - // self.weight() / self.len() is also used not to over evict - // due to entries already removed by others - while ((initial_weight > self.weight_limit && self.weight() > self.weight_limit) + + // Transient over-limit weight can persist until the next + // admit/increment_weight call, which is acceptable because the + // next admission will re-trigger eviction. + while (initial_weight > self.weight_limit && self.weight() > self.weight_limit) || self .len_watermark - .is_some_and(|w| initial_len > w && self.len() > w)) - && empty_shard < N + .is_some_and(|w| initial_len > w && self.len() > w) { - if let Some(i) = self.evict_shard(shard_seed) { - initial_weight -= i.1; - initial_len = initial_len.saturating_sub(1); - evicted.push(i) + // Power of two choices: pick the longer of two random shards. + // N == 1 short-circuits the redundant second roll. + let start = if N <= 1 { + 0 } else { - empty_shard += 1; + let a = rng.gen_range(0..N); + let b = rng.gen_range(0..N); + if self.shard_len(a) >= self.shard_len(b) { + a + } else { + b + } + }; + // Try the chosen shard first; on a miss (empty or raced), + // linearly probe successive indices. Wrapping back to + // `start` means every shard was observed empty, so we exit. + let mut shard = start; + let evicted_one = loop { + if let Some(item) = self.evict_shard_at(shard) { + break Some(item); + } + shard = (shard + 1) % N; + if shard == start { + break None; + } + }; + match evicted_one { + Some(i) => { + initial_weight = initial_weight.saturating_sub(i.1); + initial_len = initial_len.saturating_sub(1); + evicted.push(i); + } + None => break, } - // move on to the next shard - shard_seed += 1; } evicted } /// Remove the given asset. pub fn remove(&self, key: u64) -> Option<(T, usize)> { - let removed = self.units[get_shard(key, N)].write().remove(key); + let shard = get_shard(key, N); + let removed = self.units[shard].write().remove(key); if let Some((_, weight)) = removed.as_ref() { self.weight.fetch_sub(*weight, Ordering::Relaxed); - self.len.fetch_sub(1, Ordering::Relaxed); + self.decr_count(shard); } removed } @@ -213,12 +298,10 @@ impl Lru { /// /// Useful to recreate an LRU in most-to-least order pub fn insert_tail(&self, key: u64, data: T, weight: usize) -> bool { - if self.units[get_shard(key, N)] - .write() - .insert_tail(key, data, weight) - { + let shard = get_shard(key, N); + if self.units[shard].write().insert_tail(key, data, weight) { self.weight.fetch_add(weight, Ordering::Relaxed); - self.len.fetch_add(1, Ordering::Relaxed); + self.incr_count(shard); true } else { false @@ -251,6 +334,9 @@ impl Lru { } /// Return the current total weight. + /// + /// Lock-free `Relaxed` load. Best-effort: not synchronized with + /// concurrent admissions or evictions on other threads. pub fn weight(&self) -> usize { self.weight.load(Ordering::Relaxed) } @@ -285,9 +371,15 @@ impl Lru { N } - /// Get the number of items inside a shard + /// Get the number of items inside a shard. + /// + /// Lock-free `Relaxed` load from a per-shard atomic shadow. Best-effort: + /// there is no cross-thread ordering between this and [`Self::len`], and + /// `Σ shard_len(i)` is not guaranteed to equal [`Self::len`] at any + /// given instant. Suitable for eviction-balance heuristics and + /// observability; not suitable for synchronization. pub fn shard_len(&self, shard: usize) -> usize { - self.units[shard].read().len() + self.shard_lens[shard].load(Ordering::Relaxed) } /// Get the weight (total size) inside a shard @@ -437,6 +529,7 @@ impl LruUnit { true } + #[cfg(test)] pub fn len(&self) -> usize { assert_eq!(self.lookup_table.len(), self.order.len()); self.lookup_table.len() @@ -620,7 +713,6 @@ mod test_lru { assert_eq!(lru.len(), 6); let evicted = lru.evict_to_limit(); - // NOTE: there is a low chance this test would fail see the TODO in evict_to_limit assert_eq!(lru.weight(), 6); assert_eq!(lru.len(), 3); assert_eq!(lru.evicted_weight(), 6); @@ -715,6 +807,119 @@ mod test_lru { assert!(!lru.insert_tail(6, 6, 7)); } + #[test] + fn test_evict_to_limit_p2c_bias() { + use rand::rngs::StdRng; + use rand::SeedableRng; + + // Shard 0 starts with 50 items, shard 1 with 10 (all weight 1). + // weight_limit=30 forces 30 evictions. P2C-by-length should pick + // shard 0 (the longer one) most of the time, driving toward + // balance. Expected share from shard 0: P2C ≈ 0.75 (P(shard 0) + // = 3/4 per pick while it stays longer), uniform ≈ 0.50, + // always-shortest ≈ 0.67 (capped by shard 1's 10 items), + // always-longest ≈ 1.0. The (0.65..0.95) window distinguishes + // P2C from uniform; the upper bound catches a degenerate + // always-longest regression. + const TRIALS: u64 = 50; + let mut total_from_shard0 = 0usize; + let mut total_evicted = 0usize; + + for seed in 0..TRIALS { + let lru = Lru::::with_capacity(30, 64); + for k in 0..50u64 { + // even keys → shard 0 + lru.admit(k * 2, k * 2, 1); + } + for k in 0..10u64 { + // odd keys → shard 1 + lru.admit(k * 2 + 1, k * 2 + 1, 1); + } + assert_eq!(lru.weight(), 60); + + let mut rng = StdRng::seed_from_u64(seed); + let evicted = lru.evict_to_limit_with_rng(&mut rng); + assert!( + lru.weight() <= 30, + "post-eviction weight {} exceeds limit", + lru.weight() + ); + total_from_shard0 += evicted.iter().filter(|(k, _)| k % 2 == 0).count(); + total_evicted += evicted.len(); + } + + assert!(total_evicted > 1000, "too few evictions: {total_evicted}"); + let share = total_from_shard0 as f64 / total_evicted as f64; + assert!( + (0.65..0.95).contains(&share), + "expected shard-0 eviction share in 0.65..0.95 (P2C ≈ 0.75); got {share}" + ); + } + + #[test] + fn test_evict_to_limit_break_on_empty_shards_over_limit() { + // Force `weight` above the limit while every shard is empty + // (simulating bookkeeping skew). The linear probe must wrap + // around all N shards and exit cleanly. + let lru = Lru::::with_capacity(10, 16); + lru.weight.fetch_add(100, Ordering::Relaxed); + assert_eq!(lru.evict_to_limit().len(), 0); + } + + #[test] + fn test_watermark_eviction_with_zero_weight_items() { + // All items have weight 0 so the weight-limit guard never fires; + // only the length watermark drives eviction. P2C-by-length should + // still reach the watermark regardless of weight values. + let lru = Lru::::with_capacity_and_watermark(usize::MAX / 2, 10, Some(2)); + for k in 0..6u64 { + lru.insert_tail(k, k, 0); + } + assert_eq!(lru.len(), 6); + assert_eq!(lru.weight(), 0); + let evicted = lru.evict_to_limit(); + assert_eq!(lru.len(), 2); + assert_eq!(evicted.len(), 4); + } + + #[test] + fn test_evict_to_limit_with_mostly_empty_shards() { + // 7/8 shards empty: both random rolls land on empty shards ~77% + // of the time, exercising the linear-probe fallback heavily. + let lru = Lru::::with_capacity(2, 16); + for k in 0..8u64 { + // multiples of 8 hash to shard 0 + lru.admit(k * 8, k * 8, 1); + } + assert_eq!(lru.weight(), 8); + + let evicted = lru.evict_to_limit(); + assert_eq!(lru.weight(), 2); + assert_eq!(evicted.len(), 6); + assert!(evicted.iter().all(|(k, _)| k % 8 == 0)); + } + + #[test] + fn test_evict_to_limit_below_limit_returns_immediately() { + // Smoke test: outer guard short-circuits when already under limit. + let lru = Lru::::with_capacity(0, 16); + assert_eq!(lru.evict_to_limit().len(), 0); + } + + #[test] + fn test_evict_to_limit_n1() { + // N=1 is a trivial special case in the selection logic; ensure + // basic eviction still works. + let lru = Lru::::with_capacity(2, 16); + for k in 0..5u64 { + lru.admit(k, k, 1); + } + assert_eq!(lru.weight(), 5); + let evicted = lru.evict_to_limit(); + assert_eq!(lru.weight(), 2); + assert_eq!(evicted.len(), 3); + } + #[test] fn test_watermark_eviction() { const WEIGHT_LIMIT: usize = usize::MAX / 2; From 2536867e3fc71e565a42e3531ea69eda49744f10 Mon Sep 17 00:00:00 2001 From: Ian Crutcher Date: Fri, 1 May 2026 13:16:34 -0500 Subject: [PATCH 42/52] Adding curves and second keyshare setting to httppeer hash --- .bleep | 2 +- pingora-core/src/upstreams/peer.rs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.bleep b/.bleep index 8db81aa1..113ea0a7 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -5c050d84d66e44913eb6bc46f2b71aa1916ce77b \ No newline at end of file +a09076b1cc9dc4c07bc37373fd3ee02101ab64ad \ No newline at end of file diff --git a/pingora-core/src/upstreams/peer.rs b/pingora-core/src/upstreams/peer.rs index b5ec9d76..c7f5e40c 100644 --- a/pingora-core/src/upstreams/peer.rs +++ b/pingora-core/src/upstreams/peer.rs @@ -698,6 +698,8 @@ impl Hash for HttpPeer { // from the reuse hash for now. These are per-connection settings applied at handshake // time and may be revisited alongside other h2 settings that could be dynamically // adjusted over the lifetime of a connection. + self.options.curves.hash(state); + self.options.second_keyshare.hash(state); } } From ab48509e32d5849d9cf46cbc11aae0092d59e547 Mon Sep 17 00:00:00 2001 From: ewang Date: Wed, 6 May 2026 09:50:23 -0700 Subject: [PATCH 43/52] Add working_directory option for daemon mode This option is then passed to daemonize as the child process immediately runs chdir. --- .bleep | 2 +- docs/user_guide/conf.md | 1 + pingora-core/src/server/configuration/mod.rs | 31 ++++++++++++++++++++ pingora-core/src/server/daemon.rs | 8 +++-- 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/.bleep b/.bleep index 113ea0a7..c51df835 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -a09076b1cc9dc4c07bc37373fd3ee02101ab64ad \ No newline at end of file +1d7015f530f0547ee3fdaee8759930970e18ab0b \ No newline at end of file diff --git a/docs/user_guide/conf.md b/docs/user_guide/conf.md index 70a8f569..837dc519 100644 --- a/docs/user_guide/conf.md +++ b/docs/user_guide/conf.md @@ -23,6 +23,7 @@ group: webusers | threads | number of threads per service | number | | user | the user the pingora server should be run under after daemonization | string | | group | the group the pingora server should be run under after daemonization | string | +| working_directory | the working directory for the daemonized process | string | | client_bind_to_ipv4 | source IPv4 addresses to bind to when connecting to server | list of string | | client_bind_to_ipv6 | source IPv6 addresses to bind to when connecting to server| list of string | | ca_file | The path to the root CA file | string | diff --git a/pingora-core/src/server/configuration/mod.rs b/pingora-core/src/server/configuration/mod.rs index 1f410892..dd850713 100644 --- a/pingora-core/src/server/configuration/mod.rs +++ b/pingora-core/src/server/configuration/mod.rs @@ -26,6 +26,7 @@ use serde::{Deserialize, Serialize}; use std::ffi::OsString; use std::fs; use std::num::NonZeroU64; +use std::path::PathBuf; // default maximum upstream retries for retry-able proxy errors const DEFAULT_MAX_RETRIES: usize = 16; @@ -60,6 +61,11 @@ pub struct ServerConf { pub user: Option, /// Similar to `user`, the group this process should switch to. pub group: Option, + /// Working directory for the daemonized process. + /// + /// Only applied when `daemon` is `true`; set this to start the daemon from a known cwd. + // TODO: other OS path options should likely be `PathBuf` as well. + pub working_directory: Option, /// How many threads **each** service should get. The threads are not shared across services. pub threads: usize, /// Number of listener tasks to use per fd. This allows for parallel accepts. @@ -183,6 +189,7 @@ impl Default for ServerConf { upgrade_sock: "/tmp/pingora_upgrade.sock".to_string(), user: None, group: None, + working_directory: None, threads: 1, listener_tasks_per_fd: 1, work_stealing: true, @@ -357,6 +364,7 @@ mod tests { upgrade_sock: "".to_string(), user: None, group: None, + working_directory: None, threads: 1, listener_tasks_per_fd: 1, work_stealing: true, @@ -411,6 +419,29 @@ version: 1 assert_eq!("/tmp/pingora.pid", conf.pid_file); } + #[test] + fn test_working_directory_deserializes_from_yaml_string() { + init_log(); + let conf_str = r#" +--- +version: 1 +daemon: true +working_directory: /var/lib/pingora + "#; + + let conf = ServerConf::from_yaml(conf_str).unwrap(); + assert_eq!( + conf.working_directory.as_deref(), + Some(std::path::Path::new("/var/lib/pingora")) + ); + + let yaml = serde_yaml::to_value(&conf).unwrap(); + assert_eq!( + yaml.get("working_directory"), + Some(&serde_yaml::Value::String("/var/lib/pingora".to_string())) + ); + } + #[test] fn test_zero_max_blocking_threads_is_rejected() { init_log(); diff --git a/pingora-core/src/server/daemon.rs b/pingora-core/src/server/daemon.rs index b6c95cb0..d225ca22 100644 --- a/pingora-core/src/server/daemon.rs +++ b/pingora-core/src/server/daemon.rs @@ -385,12 +385,16 @@ fn process_is_running(pid: libc::pid_t) -> bool { /// Build a [`Daemonize`] instance configured from `conf`, without calling `start()` or /// `execute()`. The caller is responsible for driving execution. fn build_daemonize(conf: &ServerConf) -> Daemonize<()> { - // TODO: customize working dir - let daemonize = Daemonize::new() .umask(0o007) // allow same group to access files but not everyone else .pid_file(&conf.pid_file); + let daemonize = if let Some(working_directory) = conf.working_directory.as_ref() { + daemonize.working_directory(working_directory) + } else { + daemonize + }; + let daemonize = if let Some(error_log) = conf.error_log.as_ref() { let err = OpenOptions::new() .append(true) From 7d3677de90d84d441d62d21e5899fcfd033aac16 Mon Sep 17 00:00:00 2001 From: Kevin Guthrie Date: Thu, 7 May 2026 15:56:23 +0000 Subject: [PATCH 44/52] Ignore test_upload_connection_die due to timing dependency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem `test_upload_connection_die` fails reliably in CI on both arm64 and x86. The test sends a 15MB upload to an nginx origin that immediately responds with 200, then kills the connection after 1s. Under CI load, the 15MB upload takes longer than 1s. When nginx sends the TCP RST, it discards the buffered 200 response (per TCP protocol semantics). The proxy sees an upstream error and resets the client connection, causing the test to fail with `ConnectionReset`. This is not a test bug — the proxy does not reliably forward early responses while still writing the request body upstream. The `select!` loop in `proxy_handle_upstream` is blocked on `send_body_to1` and cannot read the response concurrently. ## Fix Mark the test as `#[ignore]` with a detailed comment explaining the root cause. --- .bleep | 2 +- pingora-proxy/tests/test_upstream.rs | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/.bleep b/.bleep index c51df835..17fa6db5 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -1d7015f530f0547ee3fdaee8759930970e18ab0b \ No newline at end of file +cab5b1fa92196cceab015ec7ad3ad7b33aa9e9c6 \ No newline at end of file diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs index 7e85c2f8..862009e5 100644 --- a/pingora-proxy/tests/test_upstream.rs +++ b/pingora-proxy/tests/test_upstream.rs @@ -72,6 +72,25 @@ async fn test_connection_die() { assert!(body.is_err()); } +// This test is ignored because it has a fundamental timing dependency. +// +// The nginx origin sends a 200 response and flushes it, then sleeps 1s +// and kills the connection (RST). The test expects the client to always +// receive the 200 before the connection dies. +// +// This fails under CI load because the 15MB request body takes longer +// than 1s to write. The proxy's select! loop is busy writing body chunks +// upstream and can't read the 200 response concurrently. When the 1s +// expires and nginx sends a TCP RST, the RST discards all buffered data +// (including the 200) per TCP semantics. The proxy then sees an upstream +// error and resets the client connection. +// +// The underlying issue is that TCP RST discards unread buffered data, +// so the 200 response is lost even though it was sent before the RST. +// Fixing this would require the proxy to read the response before or +// concurrently with the body write completing, which is a deeper +// architectural change. +#[ignore] #[tokio::test] async fn test_upload_connection_die() { init(); From 77cce2cdb50e20986ba20d00cf740e7aba473e8b Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Mon, 4 May 2026 12:57:53 -0700 Subject: [PATCH 45/52] Add cancel-safe proxy task API for Subrequest server sessions Implement the same proxy task API functionality for subrequest server sessions as HTTP/1. Also fix the regular subrequest header write path so upgrade state is only marked after the 101 task is sent. --- .bleep | 2 +- pingora-core/src/protocols/http/server.rs | 25 +- .../src/protocols/http/subrequest/server.rs | 1848 ++++++++++++++++- 3 files changed, 1751 insertions(+), 124 deletions(-) diff --git a/.bleep b/.bleep index 17fa6db5..8ae0628e 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -cab5b1fa92196cceab015ec7ad3ad7b33aa9e9c6 \ No newline at end of file +1a80a4273bfdd2c261c6ab62019ec53b66c244bf \ No newline at end of file diff --git a/pingora-core/src/protocols/http/server.rs b/pingora-core/src/protocols/http/server.rs index 438f3cb0..bdf2bdc5 100644 --- a/pingora-core/src/protocols/http/server.rs +++ b/pingora-core/src/protocols/http/server.rs @@ -814,19 +814,24 @@ impl Session { /// Check if this session supports the cancel-safe proxy task API. /// - /// For HTTP/1.x, this can be toggled per-session via - /// [`set_proxy_tasks_enabled`](Self::set_proxy_tasks_enabled). + /// Currently supported by HTTP/1.x and Subrequest server sessions; + /// toggled per-session via [`set_proxy_tasks_enabled`](Self::set_proxy_tasks_enabled). pub fn supports_proxy_task_api(&self) -> bool { match self { Self::H1(s) => s.proxy_tasks_enabled(), - _ => false, + Self::Subrequest(s) => s.proxy_tasks_enabled(), + Self::H2(_) => false, + Self::Custom(_) => false, } } /// Enable or disable the cancel-safe proxy task API for this session. pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { - if let Self::H1(s) = self { - s.set_proxy_tasks_enabled(enabled); + match self { + Self::H1(s) => s.set_proxy_tasks_enabled(enabled), + Self::Subrequest(s) => s.set_proxy_tasks_enabled(enabled), + Self::H2(_) => {} + Self::Custom(_) => {} } } @@ -841,7 +846,7 @@ impl Session { match self { Self::H1(s) => s.send_proxy_task(task), Self::H2(_) => panic!("H2 proxy task API not yet implemented"), - Self::Subrequest(_) => panic!("Subrequest proxy task API not yet implemented"), + Self::Subrequest(s) => s.send_proxy_task(task), Self::Custom(_) => panic!("Custom proxy task API not yet implemented"), } } @@ -852,9 +857,9 @@ impl Session { pub fn has_pending_downstream_proxy_tasks(&self) -> bool { match self { Self::H1(s) => s.has_pending_proxy_tasks(), - Self::H2(_) => false, // TODO: implement for H2 - Self::Subrequest(_) => false, // TODO: implement for subrequests - Self::Custom(_) => false, // TODO: implement for custom + Self::H2(_) => false, // TODO: implement for H2 + Self::Subrequest(s) => s.has_pending_proxy_tasks(), + Self::Custom(_) => false, // TODO: implement for custom } } @@ -870,7 +875,7 @@ impl Session { match self { Self::H1(s) => s.write_proxy_tasks().await, Self::H2(_) => panic!("H2 proxy task API not yet implemented"), - Self::Subrequest(_) => panic!("Subrequest proxy task API not yet implemented"), + Self::Subrequest(s) => s.write_proxy_tasks().await, Self::Custom(_) => panic!("Custom proxy task API not yet implemented"), } } diff --git a/pingora-core/src/protocols/http/subrequest/server.rs b/pingora-core/src/protocols/http/subrequest/server.rs index c91dbf91..938c8c97 100644 --- a/pingora-core/src/protocols/http/subrequest/server.rs +++ b/pingora-core/src/protocols/http/subrequest/server.rs @@ -36,13 +36,14 @@ use bytes::Bytes; use http::HeaderValue; use http::{header, header::AsHeaderName, HeaderMap, Method}; use log::{debug, trace, warn}; -use pingora_error::{Error, ErrorType::*, OkOrErr, Result}; +use pingora_error::{Error, ErrorType::*, OkOrErr, OrErr, Result}; use pingora_http::{RequestHeader, ResponseHeader}; use pingora_timeout::timeout; +use std::collections::VecDeque; use std::time::Duration; use tokio::sync::{mpsc, oneshot}; -use super::body::{BodyReader, BodyWriter}; +use super::body::{BodyMode, BodyReader, BodyWriter, PREMATURE_BODY_END}; use crate::protocols::http::{ body_buffer::FixedBuffer, server::Session as GenericHttpSession, @@ -53,6 +54,47 @@ use crate::protocols::http::{ }; use crate::protocols::{Digest, SocketAddr}; +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +enum StreamEndState { + #[default] + Open, + FinishRequired, + Finished, +} + +/// State for the cancel-safe proxy task write API. +#[derive(Default)] +struct ProxyTaskState { + /// Tasks remain queued until `Permit::send` commits them to the channel. + tasks: VecDeque, + /// Final `Done` is waiting for channel capacity. + finish_in_progress: bool, + /// End-of-stream observed from already-consumed tasks. + stream_end: StreamEndState, +} + +impl ProxyTaskState { + fn require_finish(&mut self) { + self.stream_end = StreamEndState::FinishRequired; + } + + fn mark_finished(&mut self) { + self.stream_end = StreamEndState::Finished; + } + + fn clear_stream_end(&mut self) { + self.stream_end = StreamEndState::Open; + } + + fn finish_required(&self) -> bool { + self.stream_end == StreamEndState::FinishRequired + } + + fn finished(&self) -> bool { + self.stream_end == StreamEndState::Finished + } +} + /// The HTTP server session pub struct HttpSession { // these are only options because we allow dropping them separately on shutdown @@ -76,6 +118,11 @@ pub struct HttpSession { // TODO: likely doesn't need to be a separate bool when/if moving away from dummy SessionV1 clear_request_body_headers: bool, digest: Option>, + /// Whether the cancel-safe proxy task API is enabled for this session. + /// Defaults to `false`. Toggle via [`set_proxy_tasks_enabled`](Self::set_proxy_tasks_enabled). + proxy_tasks_enabled: bool, + /// Cancel-safe proxy task state. + proxy_task_state: ProxyTaskState, } /// A handle to the subrequest session itself to interact or read from it. @@ -133,6 +180,8 @@ impl HttpSession { upgraded: false, clear_request_body_headers: false, digest: digest.map(Box::new), + proxy_tasks_enabled: false, + proxy_task_state: ProxyTaskState::default(), }, SubrequestHandle { tx: downstream_tx, @@ -310,12 +359,24 @@ impl HttpSession { // XXX: don't add additional downstream headers, unlike h1, subreq is mostly treated as a pipe - // Allow informational header (excluding 101) to pass through without affecting the state - // of the request + let mut upgrade_ok: Option = None; if header.status == 101 || !header.status.is_informational() { - // reset request body to done for incomplete upgrade handshakes - if let Some(upgrade_ok) = self.is_upgrade(&header) { - if upgrade_ok { + upgrade_ok = self.is_upgrade(&header); + } + + // TODO propagate h2 end + debug!("send response header (subrequest)"); + match self + .tx + .as_mut() + .expect("tx valid before shutdown") + .send(HttpTask::Header(header.clone(), false)) + .await + { + Ok(()) => { + self.init_body_writer(&header); + self.response_written = Some(*header); + if let Some(true) = upgrade_ok { debug!("ok upgrade handshake"); // For ws we use HTTP1_0 do_read_body_until_closed // @@ -342,25 +403,10 @@ impl HttpSession { // TODO: this has no effect resetting the body counter of TE chunked self.body_reader.convert_to_close_delimited(); } - } else { + } else if upgrade_ok == Some(false) { debug!("bad upgrade handshake!"); // continue to read body as-is, this is now just a regular request } - } - self.init_body_writer(&header); - } - - // TODO propagate h2 end - debug!("send response header (subrequest)"); - match self - .tx - .as_mut() - .expect("tx valid before shutdown") - .send(HttpTask::Header(header.clone(), false)) - .await - { - Ok(()) => { - self.response_written = Some(*header); Ok(()) } Err(e) => Error::e_because(WriteError, "writing response header", e), @@ -390,40 +436,12 @@ impl HttpSession { } fn init_body_writer(&mut self, header: &ResponseHeader) { - use http::StatusCode; - /* the following responses don't have body 204, 304, and HEAD */ - if matches!( - header.status, - StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED - ) || self.get_method() == Some(&Method::HEAD) - { - self.body_writer.init_content_length(0); - return; - } - - if header.status.is_informational() && header.status != StatusCode::SWITCHING_PROTOCOLS { - // 1xx response, not enough to init body - return; - } - - if self.is_upgrade(header) == Some(true) { - self.body_writer.init_close_delimited(); - } else if is_chunked_encoding_from_headers(&header.headers) { - // transfer-encoding takes priority over content-length - self.body_writer.init_close_delimited(); - } else { - let content_length = - header_value_content_length(header.headers.get(http::header::CONTENT_LENGTH)); - match content_length { - Some(length) => { - self.body_writer.init_content_length(length); - } - None => { - /* TODO: 1. connection: keepalive cannot be used, - 2. mark connection must be closed */ - self.body_writer.init_close_delimited(); - } - } + if let Some(mode) = body_mode_for_header( + header, + self.get_method(), + self.is_upgrade(header) == Some(true), + ) { + apply_body_mode(&mut self.body_writer, mode); } } @@ -800,6 +818,428 @@ impl HttpSession { ) .await } + + // Cancel-safe proxy task API. Unlike v1's partial `AsyncWrite` state + // machines, subrequest uses mpsc `reserve()` + synchronous `Permit::send`. + + /// Whether the cancel-safe proxy task API is enabled for this session. + pub fn proxy_tasks_enabled(&self) -> bool { + self.proxy_tasks_enabled + } + + /// Enable or disable the cancel-safe proxy task API for this session. + pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { + self.proxy_tasks_enabled = enabled; + } + + /// Queue a proxy task for cancel-safe writing. + pub fn send_proxy_task(&mut self, task: HttpTask) { + self.proxy_task_state.tasks.push_back(task); + } + + /// Whether there are pending proxy tasks queued for writing. + pub fn has_pending_proxy_tasks(&self) -> bool { + self.proxy_task_state.finish_in_progress + || self.proxy_task_state.finish_required() + || !self.proxy_task_state.tasks.is_empty() + } + + /// Write queued proxy tasks. Cancelling while waiting for channel capacity + /// leaves the current task in `proxy_task_state.tasks`. + pub async fn write_proxy_tasks(&mut self) -> Result { + loop { + if self.proxy_task_state.finished() { + self.proxy_task_state.tasks.clear(); + return Ok(true); + } + + if self.proxy_task_state.finish_in_progress { + self.finish_proxy_task().await?; + self.proxy_task_state.mark_finished(); + return Ok(true); + } + + let Some(front) = self.proxy_task_state.tasks.front() else { + break; + }; + + // Tasks with no underlying channel send: handle synchronously + // without reserving a permit. + match front { + HttpTask::Done => { + self.proxy_task_state.tasks.pop_front(); + self.proxy_task_state.require_finish(); + continue; + } + HttpTask::Failed(_) => { + let HttpTask::Failed(e) = self + .proxy_task_state + .tasks + .pop_front() + .expect("queue had a Failed task at the front") + else { + unreachable!() + }; + self.proxy_task_state.clear_stream_end(); + return Err(e); + } + _ => {} + } + + if let HttpTask::Header(_, header_end) = front { + let already_sent = match self.response_written.as_ref() { + Some(resp) => !resp.status.is_informational() || self.upgraded, + None => false, + }; + if already_sent { + warn!("Respond header is already sent, cannot send again (subrequest, proxy task)"); + if *header_end { + self.proxy_task_state.require_finish(); + } + self.proxy_task_state.tasks.pop_front(); + continue; + } + } + + if let Some((upgraded_task, body_end, no_data)) = match front { + HttpTask::Body(data, end) => { + Some((false, *end, data.as_ref().is_none_or(|d| d.is_empty()))) + } + HttpTask::UpgradedBody(data, end) => { + Some((true, *end, data.as_ref().is_none_or(|d| d.is_empty()))) + } + _ => None, + } { + if upgraded_task != self.upgraded { + if upgraded_task { + panic!("Unexpected UpgradedBody task received on un-upgraded downstream session (subrequest, proxy task)"); + } else { + panic!("Unexpected Body task received on upgraded downstream session (subrequest, proxy task)"); + } + } + + if body_end { + self.proxy_task_state.require_finish(); + } + + if no_data { + self.proxy_task_state.tasks.pop_front(); + continue; + } + + match self.body_writer.body_mode { + BodyMode::Complete(_) => { + self.proxy_task_state.tasks.pop_front(); + continue; + } + BodyMode::ContentLength(total, written) if written >= total => { + self.proxy_task_state.tasks.pop_front(); + continue; + } + BodyMode::ToSelect => { + self.proxy_task_state.tasks.pop_front(); + self.proxy_task_state.clear_stream_end(); + return Error::e_explain( + InternalError, + "subrequest body proxy task before header is sent", + ); + } + _ => {} + } + } + + // `reserve()` is the only cancellation point; the queued task is + // popped only after a permit is acquired. + let tx_ref = self + .tx + .as_ref() + .ok_or_else(|| Error::explain(InternalError, "subrequest tx already shut down"))?; + let permit = match self.write_timeout { + Some(t) => match timeout(t, tx_ref.reserve()).await { + Ok(res) => res.or_err(WriteError, "subrequest channel closed")?, + Err(_) => { + return Error::e_explain( + WriteTimedout, + format!("reserving subrequest channel slot, timeout: {t:?}"), + ); + } + }, + None => tx_ref + .reserve() + .await + .or_err(WriteError, "subrequest channel closed")?, + }; + + // From here until `permit.send`, no `.await`; dispatch is atomic. + let task = self + .proxy_task_state + .tasks + .pop_front() + .expect("queue non-empty"); + + match task { + HttpTask::Header(header, hdr_end) => { + if hdr_end { + self.proxy_task_state.require_finish(); + } + + let upgrade_ok = if header.status == 101 || !header.status.is_informational() { + let outcome = self.v1_inner.is_upgrade(&header); + let mode = body_mode_for_header( + &header, + self.v1_inner.get_method(), + outcome == Some(true), + ); + (outcome, mode) + } else { + (None, None) + }; + + permit.send(HttpTask::Header(header.clone(), false)); + + if let Some(mode) = upgrade_ok.1 { + apply_body_mode(&mut self.body_writer, mode); + } + self.response_written = Some(*header); + if let Some(true) = upgrade_ok.0 { + debug!("ok upgrade handshake (subrequest, proxy task)"); + self.upgraded = true; + if self.body_reader.need_init() { + self.init_body_reader(); + } else { + self.body_reader.convert_to_close_delimited(); + } + } else if upgrade_ok.0 == Some(false) { + debug!("bad upgrade handshake! (subrequest, proxy task)"); + } + } + + HttpTask::Body(data, end) => { + if end { + self.proxy_task_state.require_finish(); + } + dispatch_body_inline( + &mut self.body_writer, + &mut self.body_bytes_sent, + self.upgraded, + data, + /* upgraded_task = */ false, + permit, + )?; + } + HttpTask::UpgradedBody(data, end) => { + if end { + self.proxy_task_state.require_finish(); + } + dispatch_body_inline( + &mut self.body_writer, + &mut self.body_bytes_sent, + self.upgraded, + data, + /* upgraded_task = */ true, + permit, + )?; + } + + HttpTask::Trailer(trailers) => { + permit.send(HttpTask::Trailer(trailers)); + self.proxy_task_state.require_finish(); + } + + HttpTask::Done | HttpTask::Failed(_) => { + unreachable!("Done/Failed are handled above without reserving a permit") + } + } + } + + // Match `response_duplex_vec`: finish whenever any task signalled EOS. + if self.proxy_task_state.finish_required() || self.body_writer.finished() { + self.finish_proxy_task().await?; + self.proxy_task_state.mark_finished(); + return Ok(true); + } + + Ok(self.body_writer.finished()) + } + + async fn finish_proxy_task(&mut self) -> Result<()> { + if matches!( + &self.body_writer.body_mode, + BodyMode::Complete(_) | BodyMode::ToSelect + ) { + self.maybe_force_close_body_reader(); + self.proxy_task_state.finish_in_progress = false; + return Ok(()); + } + + if let BodyMode::ContentLength(total, written) = self.body_writer.body_mode { + if written < total { + self.body_writer.body_mode = BodyMode::Complete(written); + self.proxy_task_state.finish_in_progress = false; + self.proxy_task_state.clear_stream_end(); + return Error::e_explain( + PREMATURE_BODY_END, + format!( + "Content-length: {total} bytes written: {written} (subrequest, proxy task)" + ), + ); + } + } + + self.proxy_task_state.finish_in_progress = true; + self.dispatch_finish().await?; + self.proxy_task_state.finish_in_progress = false; + Ok(()) + } + + /// Dispatch the final `HttpTask::Done`, mirroring `body_writer::finish`. + async fn dispatch_finish(&mut self) -> Result<()> { + // Reserve cancel-safely, then synchronously update body_mode and send. + let tx_ref = self + .tx + .as_ref() + .ok_or_else(|| Error::explain(InternalError, "subrequest tx already shut down"))?; + let permit = match self.write_timeout { + Some(t) => match timeout(t, tx_ref.reserve()).await { + Ok(res) => res.or_err(WriteError, "subrequest channel closed")?, + Err(_) => { + return Error::e_explain( + WriteTimedout, + format!("reserving subrequest channel slot for finish, timeout: {t:?}"), + ); + } + }, + None => tx_ref + .reserve() + .await + .or_err(WriteError, "subrequest channel closed")?, + }; + + match self.body_writer.body_mode { + BodyMode::ContentLength(_total, written) => { + self.body_writer.body_mode = BodyMode::Complete(written); + permit.send(HttpTask::Done); + } + BodyMode::UntilClose(written) => { + self.body_writer.body_mode = BodyMode::Complete(written); + permit.send(HttpTask::Done); + } + BodyMode::Complete(_) => { + unreachable!("no-op body modes are handled before reserve") + } + BodyMode::ToSelect => { + unreachable!("no-op body modes are handled before reserve") + } + } + self.maybe_force_close_body_reader(); + Ok(()) + } +} + +fn body_mode_for_header( + header: &ResponseHeader, + method: Option<&Method>, + is_upgrade_ok: bool, +) -> Option { + use http::StatusCode; + if header.status.is_informational() && header.status != StatusCode::SWITCHING_PROTOCOLS { + return None; + } + if matches!( + header.status, + StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED + ) || method == Some(&Method::HEAD) + { + return Some(BodyMode::ContentLength(0, 0)); + } + if is_upgrade_ok || is_chunked_encoding_from_headers(&header.headers) { + Some(BodyMode::UntilClose(0)) + } else { + let content_length = + header_value_content_length(header.headers.get(http::header::CONTENT_LENGTH)); + match content_length { + Some(length) => Some(BodyMode::ContentLength(length, 0)), + None => Some(BodyMode::UntilClose(0)), + } + } +} + +fn apply_body_mode(body_writer: &mut BodyWriter, mode: BodyMode) { + match mode { + BodyMode::ContentLength(total, 0) => body_writer.init_content_length(total), + BodyMode::UntilClose(0) => body_writer.init_close_delimited(), + _ => body_writer.body_mode = mode, + } +} + +/// Body dispatch variant that avoids borrowing the whole session while a +/// channel permit borrows `self.tx`. +fn dispatch_body_inline( + body_writer: &mut BodyWriter, + body_bytes_sent: &mut usize, + upgraded: bool, + data: Option, + upgraded_task: bool, + permit: mpsc::Permit<'_, HttpTask>, +) -> Result<()> { + if upgraded_task != upgraded { + if upgraded_task { + panic!("Unexpected UpgradedBody task received on un-upgraded downstream session (subrequest, proxy task)"); + } else { + panic!("Unexpected Body task received on upgraded downstream session (subrequest, proxy task)"); + } + } + + let Some(d) = data else { + drop(permit); + return Ok(()); + }; + if d.is_empty() { + drop(permit); + return Ok(()); + } + + let (to_count, next_mode) = match &body_writer.body_mode { + BodyMode::ContentLength(total, written) => { + if written >= total { + drop(permit); + return Ok(()); + } + let remaining = *total - *written; + let to_write = if remaining < d.len() { + warn!("Trying to write data over content-length (subrequest, proxy task): {total}"); + remaining + } else { + d.len() + }; + ( + to_write, + BodyMode::ContentLength(*total, *written + to_write), + ) + } + BodyMode::UntilClose(written) => (d.len(), BodyMode::UntilClose(*written + d.len())), + BodyMode::Complete(_) => { + drop(permit); + return Ok(()); + } + BodyMode::ToSelect => { + drop(permit); + return Error::e_explain( + InternalError, + "subrequest body proxy task before header is sent", + ); + } + }; + + let to_send = if to_count < d.len() { + d.slice(..to_count) + } else { + d + }; + permit.send(HttpTask::Body(Some(to_send), false)); + body_writer.body_mode = next_mode; + *body_bytes_sent += to_count; + Ok(()) } #[cfg(test)] @@ -817,25 +1257,49 @@ mod tests_stream { let _ = env_logger::builder().is_test(true).try_init(); } + fn test_header(status: StatusCode) -> ResponseHeader { + ResponseHeader::build(status, None) + .expect("test status code should build a response header") + } + + fn recv_task(rx: &mut mpsc::Receiver) -> HttpTask { + rx.try_recv() + .expect("expected subrequest output task to be queued") + } + async fn session_from_input(input: &[u8]) -> (HttpSession, SubrequestHandle) { let mock_io = Builder::new().read(input).build(); let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io)); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); let (mut http_stream, handle) = HttpSession::new_from_session(&http_stream); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); (http_stream, handle) } - async fn build_upgrade_req(upgrade: &str, conn: &str) -> (HttpSession, SubrequestHandle) { + pub(super) async fn build_upgrade_req( + upgrade: &str, + conn: &str, + ) -> (HttpSession, SubrequestHandle) { let input = format!("GET / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: {upgrade}\r\nConnection: {conn}\r\n\r\n"); session_from_input(input.as_bytes()).await } - async fn build_req() -> (HttpSession, SubrequestHandle) { + pub(super) async fn build_req() -> (HttpSession, SubrequestHandle) { let input = "GET / HTTP/1.1\r\nHost: pingora.org\r\n\r\n".to_string(); session_from_input(input.as_bytes()).await } + pub(super) async fn build_head_req() -> (HttpSession, SubrequestHandle) { + let input = "HEAD / HTTP/1.1\r\nHost: pingora.org\r\n\r\n".to_string(); + session_from_input(input.as_bytes()).await + } + #[tokio::test] async fn read_basic() { init_log(); @@ -880,12 +1344,12 @@ mod tests_stream { async fn read_upgrade_req_with_1xx_response() { let (mut http_stream, _handle) = build_upgrade_req("websocket", "upgrade").await; assert!(http_stream.is_upgrade_req()); - let mut response = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap(); + let mut response = test_header(StatusCode::CONTINUE); response.set_version(http::Version::HTTP_11); http_stream .write_response_header(Box::new(response)) .await - .unwrap(); + .expect("test operation should succeed"); // 100 won't affect body state assert!(http_stream.is_body_done()); } @@ -893,13 +1357,15 @@ mod tests_stream { #[tokio::test] async fn write() { let (mut http_stream, mut handle) = build_req().await; - let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); - new_response.append_header("Foo", "Bar").unwrap(); + let mut new_response = test_header(StatusCode::OK); + new_response + .append_header("Foo", "Bar") + .expect("test operation should succeed"); http_stream .write_response_header_ref(&new_response) .await - .unwrap(); - match handle.rx.try_recv().unwrap() { + .expect("test operation should succeed"); + match recv_task(&mut handle.rx) { HttpTask::Header(header, end) => { assert_eq!(header.status, StatusCode::OK); assert_eq!(header.headers["foo"], "Bar"); @@ -912,12 +1378,12 @@ mod tests_stream { #[tokio::test] async fn write_informational() { let (mut http_stream, mut handle) = build_req().await; - let response_100 = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap(); + let response_100 = test_header(StatusCode::CONTINUE); http_stream .write_response_header_ref(&response_100) .await - .unwrap(); - match handle.rx.try_recv().unwrap() { + .expect("test operation should succeed"); + match recv_task(&mut handle.rx) { HttpTask::Header(header, end) => { assert_eq!(header.status, StatusCode::CONTINUE); assert!(!end); @@ -925,12 +1391,12 @@ mod tests_stream { t => panic!("unexpected task {t:?}"), } - let response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap(); + let response_200 = test_header(StatusCode::OK); http_stream .write_response_header_ref(&response_200) .await - .unwrap(); - match handle.rx.try_recv().unwrap() { + .expect("test operation should succeed"); + match recv_task(&mut handle.rx) { HttpTask::Header(header, end) => { assert_eq!(header.status, StatusCode::OK); assert!(!end); @@ -942,15 +1408,16 @@ mod tests_stream { #[tokio::test] async fn write_101_switching_protocol() { let (mut http_stream, mut handle) = build_upgrade_req("WebSocket", "Upgrade").await; - let mut response_101 = - ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); - response_101.append_header("Foo", "Bar").unwrap(); + let mut response_101 = test_header(StatusCode::SWITCHING_PROTOCOLS); + response_101 + .append_header("Foo", "Bar") + .expect("test operation should succeed"); http_stream .write_response_header_ref(&response_101) .await - .unwrap(); + .expect("test operation should succeed"); - match handle.rx.try_recv().unwrap() { + match recv_task(&mut handle.rx) { HttpTask::Header(header, end) => { assert_eq!(header.status, StatusCode::SWITCHING_PROTOCOLS); assert!(!end); @@ -964,16 +1431,17 @@ mod tests_stream { .write_body(wire_body.clone()) .await .unwrap() - .unwrap(); + .expect("test operation should succeed"); assert_eq!(wire_body.len(), n); // this write should be ignored - let response_502 = ResponseHeader::build(StatusCode::BAD_GATEWAY, None).unwrap(); + let response_502 = ResponseHeader::build(StatusCode::BAD_GATEWAY, None) + .expect("test operation should succeed"); http_stream .write_response_header_ref(&response_502) .await - .unwrap(); + .expect("test operation should succeed"); - match handle.rx.try_recv().unwrap() { + match recv_task(&mut handle.rx) { HttpTask::Body(body, _end) => { assert_eq!(body.unwrap().len(), n); } @@ -989,12 +1457,14 @@ mod tests_stream { async fn write_body_cl() { let (mut http_stream, _handle) = build_req().await; let wire_body = Bytes::from(&b"a"[..]); - let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); - new_response.append_header("Content-Length", "1").unwrap(); + let mut new_response = test_header(StatusCode::OK); + new_response + .append_header("Content-Length", "1") + .expect("test operation should succeed"); http_stream .write_response_header_ref(&new_response) .await - .unwrap(); + .expect("test operation should succeed"); assert_eq!( http_stream.body_writer.body_mode, BodyMode::ContentLength(1, 0) @@ -1003,29 +1473,37 @@ mod tests_stream { .write_body(wire_body.clone()) .await .unwrap() - .unwrap(); + .expect("test operation should succeed"); assert_eq!(wire_body.len(), n); - let n = http_stream.finish().await.unwrap().unwrap(); + let n = http_stream + .finish() + .await + .expect("test async operation should succeed") + .expect("test operation should succeed"); assert_eq!(wire_body.len(), n); } #[tokio::test] async fn write_body_until_close() { let (mut http_stream, _handle) = build_req().await; - let new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); + let new_response = test_header(StatusCode::OK); http_stream .write_response_header_ref(&new_response) .await - .unwrap(); + .expect("test operation should succeed"); assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0)); let wire_body = Bytes::from(&b"PAYLOAD"[..]); let n = http_stream .write_body(wire_body.clone()) .await .unwrap() - .unwrap(); + .expect("test operation should succeed"); assert_eq!(wire_body.len(), n); - let n = http_stream.finish().await.unwrap().unwrap(); + let n = http_stream + .finish() + .await + .expect("test async operation should succeed") + .expect("test operation should succeed"); assert_eq!(wire_body.len(), n); } @@ -1037,17 +1515,27 @@ mod tests_stream { let input3 = b"abc"; let mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io)); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); let (mut http_stream, handle) = HttpSession::new_from_session(&http_stream); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); handle .tx .send(HttpTask::Body(Some(Bytes::from(&input3[..])), false)) .await - .unwrap(); + .expect("test operation should succeed"); assert_eq!(http_stream.get_path(), &b"/a?q=b%20c"[..]); - let res = http_stream.read_body().await.unwrap().unwrap(); + let res = http_stream + .read_body() + .await + .expect("test async operation should succeed") + .expect("test operation should succeed"); assert_eq!(res, &input3[..]); assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3)); } @@ -1056,22 +1544,27 @@ mod tests_stream { async fn test_write_body_write_timeout() { let (mut http_stream, _handle) = build_req().await; http_stream.write_timeout = Some(Duration::from_millis(100)); - let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); - new_response.append_header("Content-Length", "10").unwrap(); + let mut new_response = test_header(StatusCode::OK); + new_response + .append_header("Content-Length", "10") + .expect("test operation should succeed"); http_stream .write_response_header_ref(&new_response) .await - .unwrap(); + .expect("test operation should succeed"); let body_write_buf = Bytes::from(&b"abc"[..]); http_stream .write_body(body_write_buf.clone()) .await - .unwrap(); + .expect("test operation should succeed"); http_stream .write_body(body_write_buf.clone()) .await - .unwrap(); - http_stream.write_body(body_write_buf).await.unwrap(); + .expect("test operation should succeed"); + http_stream + .write_body(body_write_buf) + .await + .expect("test async operation should succeed"); // channel full let last_body = Bytes::from(&b"a"[..]); let res = http_stream.write_body(last_body).await; @@ -1081,8 +1574,11 @@ mod tests_stream { #[tokio::test] async fn test_write_continue_resp() { let (mut http_stream, mut handle) = build_req().await; - http_stream.write_continue_response().await.unwrap(); - match handle.rx.try_recv().unwrap() { + http_stream + .write_continue_response() + .await + .expect("test async operation should succeed"); + match recv_task(&mut handle.rx) { HttpTask::Header(header, end) => { assert_eq!(header.status, StatusCode::CONTINUE); assert!(!end); @@ -1095,7 +1591,10 @@ mod tests_stream { let mock_io = Builder::new().read(input).build(); let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io)); // Read the request in v1 inner session to set up headers properly - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); let (http_stream, handle) = HttpSession::new_from_session(&http_stream); (http_stream, handle) } @@ -1157,9 +1656,15 @@ mod tests_stream { async fn build_upgrade_req_with_body(header: &[u8]) -> (HttpSession, SubrequestHandle) { let mock_io = Builder::new().read(header).build(); let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io)); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); let (mut http_stream, handle) = HttpSession::new_from_session(&http_stream); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); (http_stream, handle) } @@ -1179,10 +1684,14 @@ mod tests_stream { .tx .send(HttpTask::Body(Some(Bytes::from(POST_BODY_DATA)), true)) .await - .unwrap(); + .expect("test operation should succeed"); let mut buf = vec![]; - while let Some(b) = http_stream.read_body_bytes().await.unwrap() { + while let Some(b) = http_stream + .read_body_bytes() + .await + .expect("test async operation should succeed") + { buf.put_slice(&b); } assert_eq!(buf, POST_BODY_DATA); @@ -1191,12 +1700,12 @@ mod tests_stream { assert!(http_stream.is_body_done()); - let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); + let mut response = test_header(StatusCode::SWITCHING_PROTOCOLS); response.set_version(http::Version::HTTP_11); http_stream .write_response_header(Box::new(response)) .await - .unwrap(); + .expect("test operation should succeed"); // body reader type switches assert!(!http_stream.is_body_done()); @@ -1206,15 +1715,1128 @@ mod tests_stream { .tx .send(HttpTask::Body(Some(Bytes::from(&ws_data[..])), false)) .await - .unwrap(); + .expect("test operation should succeed"); - let buf = http_stream.read_body_bytes().await.unwrap().unwrap(); + let buf = http_stream + .read_body_bytes() + .await + .expect("test async operation should succeed") + .expect("test operation should succeed"); assert_eq!(buf, ws_data.as_slice()); assert!(!http_stream.is_body_done()); // EOF ends body drop(handle.tx); - assert!(http_stream.read_body_bytes().await.unwrap().is_none()); + assert!(http_stream + .read_body_bytes() + .await + .expect("test async operation should succeed") + .is_none()); assert!(http_stream.is_body_done()); } } + +#[cfg(test)] +mod test_proxy_tasks { + //! Cancel-safe proxy task API tests for the subrequest server session. + + use super::tests_stream::{build_head_req, build_req, build_upgrade_req}; + use super::*; + use http::StatusCode; + + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + fn test_header(status: StatusCode) -> ResponseHeader { + ResponseHeader::build(status, None) + .expect("test status code should build a response header") + } + + fn recv_task(rx: &mut mpsc::Receiver) -> HttpTask { + rx.try_recv() + .expect("expected subrequest output task to be queued") + } + + fn assert_rx_empty(rx: &mut mpsc::Receiver) { + assert!(matches!( + rx.try_recv(), + Err(mpsc::error::TryRecvError::Empty) + )); + } + + /// Dropping a blocked `send` does not deliver its value. + #[tokio::test(start_paused = true)] + async fn test_tokio_mpsc_send_cancel_drops_value() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(1) + .await + .expect("test async operation should succeed"); + let send_fut = tx.send(2); + tokio::pin!(send_fut); + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(10)) => {} + _ = &mut send_fut => panic!("expected the timer to win"), + }; + assert_eq!(rx.recv().await, Some(1)); + assert_eq!(rx.try_recv(), Err(mpsc::error::TryRecvError::Empty)); + } + + /// Dropping a blocked `reserve` does not consume capacity. + #[tokio::test(start_paused = true)] + async fn test_tokio_mpsc_reserve_cancel_releases_slot() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(1) + .await + .expect("test async operation should succeed"); + let reserve_fut = tx.reserve(); + tokio::pin!(reserve_fut); + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(10)) => {} + _ = &mut reserve_fut => panic!("expected the timer to win"), + }; + assert_eq!(rx.recv().await, Some(1)); + assert_eq!(rx.try_recv(), Err(mpsc::error::TryRecvError::Empty)); + } + + #[tokio::test] + async fn test_send_proxy_task_and_write() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + assert!(session.proxy_tasks_enabled()); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "5") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + assert!(!session.has_pending_proxy_tasks()); + assert_eq!(session.body_bytes_sent(), 5); + + match recv_task(&mut handle.rx) { + HttpTask::Header(h, false) => assert_eq!(h.status, StatusCode::OK), + t => panic!("expected Header, got {t:?}"), + } + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"hello"), + t => panic!("expected Body, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_informational_head_does_not_init_body_writer() { + init_log(); + + let (mut regular, mut regular_handle) = build_head_req().await; + regular + .write_response_header(Box::new(test_header(StatusCode::CONTINUE))) + .await + .expect("regular informational header write should succeed"); + assert_eq!(regular.body_writer.body_mode, BodyMode::ToSelect); + assert!(matches!( + recv_task(&mut regular_handle.rx), + HttpTask::Header(..) + )); + + let (mut proxy, mut proxy_handle) = build_head_req().await; + proxy.set_proxy_tasks_enabled(true); + proxy.send_proxy_task(HttpTask::Header( + Box::new(test_header(StatusCode::CONTINUE)), + false, + )); + assert!(!proxy + .write_proxy_tasks() + .await + .expect("proxy task write should succeed")); + assert_eq!(proxy.body_writer.body_mode, BodyMode::ToSelect); + assert!(matches!( + recv_task(&mut proxy_handle.rx), + HttpTask::Header(..) + )); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_with_timeout() { + init_log(); + // Do not drain `handle.rx` before the first write; the 5th response task blocks on capacity. + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + session.set_write_timeout(Some(Duration::from_millis(50))); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + for i in 0..5 { + session.send_proxy_task(HttpTask::Body( + Some(Bytes::from(format!("body-{i}"))), + i == 4, + )); + } + + let err = session + .write_proxy_tasks() + .await + .expect_err("full subrequest output channel should time out"); + assert_eq!(err.etype(), &WriteTimedout); + assert!(session.has_pending_proxy_tasks()); + + let mut delivered = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + assert_eq!(delivered.len(), 4); + assert!(matches!(delivered[0], HttpTask::Header(..))); + + session.set_write_timeout(None); + let end = session + .write_proxy_tasks() + .await + .expect("retry after freeing channel capacity should complete"); + assert!(end); + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + + assert_eq!(delivered.len(), 7); + assert!(matches!(delivered.last(), Some(HttpTask::Done))); + let body_count = delivered + .iter() + .filter(|t| matches!(t, HttpTask::Body(..))) + .count(); + assert_eq!(body_count, 5); + } + + #[tokio::test] + async fn test_proxy_task_channel_closed_errors() { + init_log(); + let (mut session, handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + drop(handle.rx); + + session.send_proxy_task(HttpTask::Header( + Box::new(test_header(StatusCode::OK)), + false, + )); + let err = session + .write_proxy_tasks() + .await + .expect_err("closed subrequest output channel should error"); + assert_eq!(err.etype(), &WriteError); + assert!(session.has_pending_proxy_tasks()); + } + + /// Repeatedly cancel while blocked on channel capacity, then verify the + /// receiver sees each queued task exactly once and in order. + #[tokio::test(start_paused = true)] + async fn test_proxy_task_cancel_safety() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "5") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + for i in 0..4 { + session.send_proxy_task(HttpTask::Body( + Some(Bytes::from(vec![b'A' + i as u8; 1])), + false, + )); + } + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("E")), true)); + + let mut cancel_count = 0; + let mut delivered: Vec = Vec::new(); + loop { + if !session.has_pending_proxy_tasks() { + break; + } + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => { + cancel_count += 1; + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + } + result = session.write_proxy_tasks() => { + result.expect("test operation should succeed"); + } + } + } + + assert!( + cancel_count >= 1, + "expected at least one cancellation during cancel-safe write, got {cancel_count}" + ); + + assert_eq!(session.proxy_task_state.tasks.len(), 0); + + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + + assert!(matches!(delivered[0], HttpTask::Header(_, false))); + let mut body_bytes = Vec::new(); + let mut saw_done = false; + for task in &delivered[1..] { + match task { + HttpTask::Body(Some(b), false) => body_bytes.extend_from_slice(b), + HttpTask::Done => { + assert!(!saw_done, "Done delivered more than once"); + saw_done = true; + } + t => panic!("unexpected task in delivery: {t:?}"), + } + } + assert!(saw_done, "expected Done to be delivered"); + assert_eq!( + body_bytes, b"ABCDE", + "body chunks must arrive exactly once, in order" + ); + + assert_eq!(session.body_bytes_sent(), 5); + } + + /// `was_upgraded()` must remain false if the 101 send is cancelled before + /// it reaches the subrequest channel. + #[tokio::test(start_paused = true)] + async fn test_proxy_task_upgrade_consistency() { + init_log(); + let (mut session, mut handle) = build_upgrade_req("websocket", "Upgrade").await; + assert!(session.is_upgrade_req()); + session.set_proxy_tasks_enabled(true); + + // Four 1xx headers fill the upstream channel; the 101 then blocks. + for _ in 0..4 { + session.send_proxy_task(HttpTask::Header( + Box::new(test_header(StatusCode::CONTINUE)), + false, + )); + } + let mut h101 = test_header(StatusCode::SWITCHING_PROTOCOLS); + h101.set_version(http::Version::HTTP_11); + h101.insert_header("upgrade", "websocket") + .expect("test operation should succeed"); + h101.insert_header("connection", "Upgrade") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(h101), false)); + + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_proxy_tasks() => panic!("expected reserve to be cancelled"), + }; + + assert!( + !session.was_upgraded(), + "was_upgraded must remain false until the 101 send actually completes" + ); + + for _ in 0..4 { + recv_task(&mut handle.rx); + } + session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!( + session.was_upgraded(), + "after the 101 send completes, was_upgraded must be true" + ); + } + + /// Same upgrade consistency check for the regular `write_response_header` + /// path, which also awaits on the subrequest output channel. + #[tokio::test(start_paused = true)] + async fn test_write_response_header_upgrade_cancel_consistency() { + init_log(); + let (mut session, mut handle) = build_upgrade_req("websocket", "Upgrade").await; + + for _ in 0..4 { + session + .write_response_header(Box::new(test_header(StatusCode::CONTINUE))) + .await + .expect("test operation should succeed"); + } + + let mut h101 = test_header(StatusCode::SWITCHING_PROTOCOLS); + h101.set_version(http::Version::HTTP_11); + h101.insert_header("upgrade", "websocket") + .expect("test operation should succeed"); + h101.insert_header("connection", "Upgrade") + .expect("test operation should succeed"); + + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_response_header(Box::new(h101.clone())) => { + panic!("expected header send to be cancelled") + } + }; + assert!(!session.was_upgraded()); + assert_eq!(session.body_writer.body_mode, BodyMode::ToSelect); + + for _ in 0..4 { + recv_task(&mut handle.rx); + } + session + .write_response_header(Box::new(h101)) + .await + .expect("test async operation should succeed"); + assert!(session.was_upgraded()); + } + + /// Trailers are dispatched correctly through `write_proxy_tasks`. + /// Matching regular `response_duplex_vec`, a final `Done` follows Trailer. + #[tokio::test] + async fn test_proxy_task_trailers() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("hi")), false)); + let mut trailers = http::HeaderMap::new(); + trailers.insert("x-final", http::HeaderValue::from_static("done")); + session.send_proxy_task(HttpTask::Trailer(Some(Box::new(trailers)))); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Body(..))); + match recv_task(&mut handle.rx) { + HttpTask::Trailer(Some(t)) => { + assert_eq!( + t.get("x-final").expect("test trailer should be present"), + "done" + ); + } + t => panic!("expected Trailer, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_trailer_before_content_length_complete_errors() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "5") + .expect("test content-length header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + let mut trailers = HeaderMap::new(); + trailers.insert("x-final", http::HeaderValue::from_static("done")); + session.send_proxy_task(HttpTask::Trailer(Some(Box::new(trailers)))); + + let err = session + .write_proxy_tasks() + .await + .expect_err("trailers before content-length body completion should error"); + assert_eq!(err.etype(), &PREMATURE_BODY_END); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Trailer(..))); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_trailer_cancel_safety() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + for _ in 0..4 { + session.send_proxy_task(HttpTask::Header( + Box::new(test_header(StatusCode::CONTINUE)), + false, + )); + } + let mut trailers = HeaderMap::new(); + trailers.insert("x-final", http::HeaderValue::from_static("done")); + session.send_proxy_task(HttpTask::Trailer(Some(Box::new(trailers)))); + + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_proxy_tasks() => panic!("expected trailer reserve to be cancelled"), + }; + + let mut prefix = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + prefix.push(task); + } + assert_eq!(prefix.len(), 4); + assert!(prefix.iter().all(|t| matches!(t, HttpTask::Header(..)))); + let end = session + .write_proxy_tasks() + .await + .expect("resume after trailer cancellation should complete"); + assert!(end); + + match recv_task(&mut handle.rx) { + HttpTask::Trailer(Some(t)) => { + assert_eq!(t.get("x-final").expect("trailer present"), "done") + } + t => panic!("expected Trailer after resume, got {t:?}"), + } + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_trailer_cancel_safety_after_body() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + for _ in 0..3 { + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("x")), false)); + } + let mut trailers = HeaderMap::new(); + trailers.insert("x-final", http::HeaderValue::from_static("done")); + session.send_proxy_task(HttpTask::Trailer(Some(Box::new(trailers)))); + + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_proxy_tasks() => panic!("expected trailer reserve to be cancelled"), + }; + + let mut prefix = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + prefix.push(task); + } + assert_eq!(prefix.len(), 4); + assert!(matches!(prefix[0], HttpTask::Header(..))); + assert!(prefix[1..].iter().all(|t| matches!(t, HttpTask::Body(..)))); + let end = session + .write_proxy_tasks() + .await + .expect("resume after body trailer cancellation should complete"); + assert!(end); + + match recv_task(&mut handle.rx) { + HttpTask::Trailer(Some(t)) => { + assert_eq!(t.get("x-final").expect("trailer present"), "done") + } + t => panic!("expected Trailer after resume, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + /// `body_bytes_sent` is only incremented after the synchronous + /// `Permit::send`, not on a cancelled `reserve().await`. + #[tokio::test(start_paused = true)] + async fn test_proxy_task_body_counter_no_double_count() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "12") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("AAAA")), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("BBBB")), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("CCCC")), true)); + + let mut received_body_bytes = Vec::new(); + loop { + if !session.has_pending_proxy_tasks() { + break; + } + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => { + while let Ok(task) = handle.rx.try_recv() { + if let HttpTask::Body(Some(b), _) = &task { + received_body_bytes.extend_from_slice(b); + } + } + assert!( + session.body_bytes_sent() <= received_body_bytes.len(), + "body_bytes_sent ({}) must not exceed bytes actually delivered ({})", + session.body_bytes_sent(), + received_body_bytes.len(), + ); + } + result = session.write_proxy_tasks() => { + result.expect("test operation should succeed"); + } + } + } + + while let Ok(task) = handle.rx.try_recv() { + if let HttpTask::Body(Some(b), _) = &task { + received_body_bytes.extend_from_slice(b); + } + } + + assert_eq!(session.body_bytes_sent(), 12); + assert_eq!(&received_body_bytes[..], b"AAAABBBBCCCC"); + } + + /// Cancelling while reserving capacity for the final Done must leave the + /// finish operation resumable. + #[tokio::test(start_paused = true)] + async fn test_proxy_task_finish_cancel_safety() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("a")), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("b")), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("c")), true)); + + // Header + three bodies fill the 4-slot channel. The final Done + // reserve blocks and is cancelled. + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_proxy_tasks() => panic!("expected finish reserve to be cancelled"), + }; + assert!(session.proxy_task_state.finish_in_progress); + + let first = recv_task(&mut handle.rx); + assert!(matches!(first, HttpTask::Header(..))); + + // Only one slot is available. Resuming must emit exactly one Done and + // return without trying to reserve a second slot. + let end = tokio::time::timeout(Duration::from_millis(5), session.write_proxy_tasks()) + .await + .expect("resume should not need a second channel slot") + .expect("test operation should succeed"); + assert!(end); + + let mut delivered = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + assert_eq!(delivered.len(), 4); + assert!(matches!(delivered.last(), Some(HttpTask::Done))); + assert_eq!( + delivered + .iter() + .filter(|t| matches!(t, HttpTask::Done)) + .count(), + 1 + ); + } + + #[tokio::test] + async fn test_proxy_task_done_only_noops_without_channel() { + init_log(); + let (mut session, _handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + session.send_proxy_task(HttpTask::Done); + session.shutdown(); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + } + + #[tokio::test] + async fn test_proxy_task_done_only_noops_with_live_channel() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + session.send_proxy_task(HttpTask::Done); + + let end = session + .write_proxy_tasks() + .await + .expect("Done-only proxy task should complete without channel output"); + assert!(end); + assert!(!session.has_pending_proxy_tasks()); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_header_only_end() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::NO_CONTENT); + session.send_proxy_task(HttpTask::Header(Box::new(header), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_head_response_drops_body() { + init_log(); + let (mut session, mut handle) = build_head_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "10") + .expect("test content-length header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("not-sent")), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("HEAD proxy task response should complete"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + assert_eq!(session.body_bytes_sent(), 0); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_duplicate_final_header_does_not_reserve() { + init_log(); + let (mut session, _handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + + for _ in 0..3 { + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("x")), false)); + } + let duplicate = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(duplicate), false)); + + tokio::time::timeout(Duration::from_millis(5), session.write_proxy_tasks()) + .await + .expect("duplicate final header should be dropped without reserving") + .expect("test operation should succeed"); + assert!(!session.has_pending_proxy_tasks()); + } + + #[tokio::test] + async fn test_proxy_task_duplicate_final_header_preserves_end() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + recv_task(&mut handle.rx); + + let duplicate = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(duplicate), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_failed_propagates_without_sending_later_tasks() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + session.send_proxy_task(HttpTask::Failed(Error::explain(InternalError, "boom"))); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("late")), true)); + + let err = session + .write_proxy_tasks() + .await + .expect_err("Failed proxy task should propagate error"); + assert_eq!(err.etype(), &InternalError); + assert!(session.has_pending_proxy_tasks()); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_failed_clears_sticky_eos() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(None, true)); + session.send_proxy_task(HttpTask::Failed(Error::explain(InternalError, "boom"))); + + let err = session + .write_proxy_tasks() + .await + .expect_err("Failed after EOS should still propagate error"); + assert_eq!(err.etype(), &InternalError); + assert!(!session.has_pending_proxy_tasks()); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + assert_rx_empty(&mut handle.rx); + + let end = session + .write_proxy_tasks() + .await + .expect("retry after Failed should not emit stale Done"); + assert!(!end); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_body_before_header_errors() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("body")), true)); + + let err = session + .write_proxy_tasks() + .await + .expect_err("body before response header should be rejected"); + assert_eq!(err.etype(), &InternalError); + assert!(!session.has_pending_proxy_tasks()); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_none_body_does_not_reserve() { + init_log(); + let (mut session, _handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + for _ in 0..3 { + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("x")), false)); + } + session.send_proxy_task(HttpTask::Body(None, false)); + + tokio::time::timeout(Duration::from_millis(5), session.write_proxy_tasks()) + .await + .expect("Body(None) should not reserve channel capacity") + .expect("test operation should succeed"); + assert!(!session.has_pending_proxy_tasks()); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_no_data_eos_survives_cancellation() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + for _ in 0..3 { + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("x")), false)); + } + session.send_proxy_task(HttpTask::Body(None, true)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::new()), true)); + // This later body blocks on the full channel after the no-data EOS + // task has been consumed. The sticky EOS flag must survive that cancel. + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("y")), false)); + + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_proxy_tasks() => panic!("expected reserve after no-data EOS to be cancelled"), + }; + + let mut prefix = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + prefix.push(task); + } + assert_eq!(prefix.len(), 4); + assert!(matches!(prefix[0], HttpTask::Header(..))); + assert!(prefix[1..].iter().all(|t| matches!(t, HttpTask::Body(..)))); + let end = session + .write_proxy_tasks() + .await + .expect("resume after no-data EOS cancellation should complete"); + assert!(end); + + let mut delivered = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + assert_eq!(delivered.len(), 2); + match &delivered[0] { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"y"), + t => panic!("expected Body(y) after resume, got {t:?}"), + } + assert!(matches!(delivered[1], HttpTask::Done)); + assert_eq!(session.body_bytes_sent(), 4); + } + + #[tokio::test] + async fn test_proxy_task_content_length_overrun_truncates() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "3") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("abcdef")), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"abc"), + t => panic!("expected truncated Body, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + assert_eq!(session.body_bytes_sent(), 3); + } + + #[tokio::test] + async fn test_proxy_task_exact_content_length_without_end_sends_done() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "3") + .expect("test content-length header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("abc")), false)); + + let end = session + .write_proxy_tasks() + .await + .expect("exact content-length proxy task response should complete"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"abc"), + t => panic!("expected exact content-length Body, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_late_tasks_after_finished_are_dropped() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "3") + .expect("test content-length header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("abc")), true)); + session + .write_proxy_tasks() + .await + .expect("initial response should complete"); + + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Body(..))); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + + let mut trailers = HeaderMap::new(); + trailers.insert("x-late", http::HeaderValue::from_static("ignored")); + session.send_proxy_task(HttpTask::Trailer(Some(Box::new(trailers)))); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("late")), true)); + session.send_proxy_task(HttpTask::Failed(Error::explain(InternalError, "late"))); + + let end = session + .write_proxy_tasks() + .await + .expect("late tasks after finished stream should be dropped"); + assert!(end); + assert_rx_empty(&mut handle.rx); + assert!(!session.has_pending_proxy_tasks()); + } + + #[tokio::test] + async fn test_proxy_task_chunked_header_uses_until_close() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("transfer-encoding", "chunked") + .expect("test transfer-encoding header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("chunk")), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("chunked proxy task response should complete"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"chunk"), + t => panic!("expected chunked body task, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_eq!(session.body_bytes_sent(), 5); + } + + #[tokio::test] + async fn test_proxy_task_premature_content_length_errors_before_done() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "5") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("hi")), true)); + + let err = session + .write_proxy_tasks() + .await + .expect_err("premature content-length should error before Done"); + assert_eq!(err.etype(), &PREMATURE_BODY_END); + assert_eq!(session.body_writer.body_mode, BodyMode::Complete(2)); + assert!(!session.has_pending_proxy_tasks()); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"hi"), + t => panic!("expected body before premature end, got {t:?}"), + } + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + #[should_panic( + expected = "Unexpected UpgradedBody task received on un-upgraded downstream session" + )] + async fn test_upgraded_body_on_non_upgraded_session_panics() { + init_log(); + let (mut session, _handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + assert!(!session.was_upgraded()); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::UpgradedBody(Some(Bytes::from("ws")), true)); + + let _ = session.write_proxy_tasks().await; + } + + #[tokio::test] + #[should_panic(expected = "Unexpected Body task received on upgraded downstream session")] + async fn test_body_on_upgraded_session_panics() { + init_log(); + let (mut session, _handle) = build_upgrade_req("websocket", "Upgrade").await; + session.set_proxy_tasks_enabled(true); + + let mut h101 = test_header(StatusCode::SWITCHING_PROTOCOLS); + h101.set_version(http::Version::HTTP_11); + h101.insert_header("upgrade", "websocket") + .expect("test operation should succeed"); + h101.insert_header("connection", "Upgrade") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(h101), false)); + session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("plain")), true)); + let _ = session.write_proxy_tasks().await; + } + + #[tokio::test] + async fn test_proxy_task_upgraded_body_happy_path() { + init_log(); + let (mut session, mut handle) = build_upgrade_req("websocket", "Upgrade").await; + session.set_proxy_tasks_enabled(true); + + let mut h101 = test_header(StatusCode::SWITCHING_PROTOCOLS); + h101.set_version(http::Version::HTTP_11); + h101.insert_header("upgrade", "websocket") + .expect("test upgrade header is valid"); + h101.insert_header("connection", "Upgrade") + .expect("test connection header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(h101), false)); + session.send_proxy_task(HttpTask::UpgradedBody(Some(Bytes::from("ws")), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("upgraded proxy task response should complete"); + assert!(end); + assert!(session.was_upgraded()); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"ws"), + t => panic!("expected upgraded body task, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_eq!(session.body_bytes_sent(), 2); + } + + #[tokio::test] + #[should_panic( + expected = "Unexpected UpgradedBody task received on un-upgraded downstream session" + )] + async fn test_upgraded_body_on_non_upgraded_session_panics_while_full() { + init_log(); + let (mut session, _handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + for _ in 0..4 { + session.send_proxy_task(HttpTask::Header( + Box::new(test_header(StatusCode::CONTINUE)), + false, + )); + } + session.send_proxy_task(HttpTask::UpgradedBody(Some(Bytes::from("ws")), true)); + let _ = session.write_proxy_tasks().await; + } +} From c0845a8693b0792a6ccd0626e8475990f7269af2 Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Mon, 4 May 2026 18:23:17 -0700 Subject: [PATCH 46/52] Add per-listener L4 buffer configuration Creates ListenerConfig to hold this new config and allow for future extensibility. --- .bleep | 2 +- pingora-core/src/listeners/mod.rs | 184 +++++++++++++++++++++++- pingora-core/src/protocols/l4/stream.rs | 82 +++++++++-- pingora-core/src/services/listening.rs | 7 +- 4 files changed, 256 insertions(+), 19 deletions(-) diff --git a/.bleep b/.bleep index 8ae0628e..d1911283 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -1a80a4273bfdd2c261c6ab62019ec53b66c244bf \ No newline at end of file +5281b97daa3213287999fb97c2a5de57ae565011 \ No newline at end of file diff --git a/pingora-core/src/listeners/mod.rs b/pingora-core/src/listeners/mod.rs index f2e649f8..5384dbd6 100644 --- a/pingora-core/src/listeners/mod.rs +++ b/pingora-core/src/listeners/mod.rs @@ -86,6 +86,9 @@ use std::{any::Any, fs::Permissions, sync::Arc}; use l4::{ListenerEndpoint, Stream as L4Stream}; use tls::{Acceptor, TlsSettings}; +pub use crate::protocols::l4::stream::{ + L4BufferSettings, DEFAULT_L4_READ_BUFFER_SIZE, DEFAULT_L4_WRITE_BUFFER_SIZE, +}; pub use crate::protocols::tls::ALPN; use crate::protocols::GetSocketDigest; pub use l4::{ServerAddress, TcpSocketOptions}; @@ -121,6 +124,7 @@ pub type TlsAcceptCallbacks = Box; struct TransportStackBuilder { l4: ServerAddress, tls: Option, + l4_buffer: L4BufferSettings, #[cfg(feature = "connection_filter")] connection_filter: Option>, } @@ -148,14 +152,94 @@ impl TransportStackBuilder { Ok(TransportStack { l4, tls: self.tls.take().map(|tls| Arc::new(tls.build())), + l4_buffer: self.l4_buffer, }) } } +/// Configuration for one listening endpoint. +/// +/// This configures the endpoint address and endpoint-specific transport +/// settings such as [`TcpSocketOptions`], [`TlsSettings`], and L4 +/// [`BufStream`](tokio::io::BufStream) buffer sizes. +pub struct ListenerConfig { + l4: ServerAddress, + tls: Option, + l4_buffer: L4BufferSettings, +} + +impl ListenerConfig { + /// Create a TCP listening endpoint config. + pub fn tcp(addr: impl Into) -> Self { + Self { + l4: ServerAddress::Tcp(addr.into(), None), + tls: None, + l4_buffer: L4BufferSettings::default(), + } + } + + /// Create a Unix domain socket listening endpoint config. + #[cfg(unix)] + pub fn uds(addr: impl Into) -> Self { + Self { + l4: ServerAddress::Uds(addr.into(), None), + tls: None, + l4_buffer: L4BufferSettings::default(), + } + } + + /// Set TCP socket options for this endpoint. + /// + /// # Panics + /// + /// Panics if this endpoint is not TCP. + #[track_caller] + pub fn tcp_socket_options(mut self, options: TcpSocketOptions) -> Self { + match &mut self.l4 { + ServerAddress::Tcp(_, opt) => *opt = Some(options), + #[cfg(unix)] + ServerAddress::Uds(_, _) => { + panic!("TCP socket options can only be set on TCP endpoints") + } + } + self + } + + /// Set Unix domain socket permissions for this endpoint. + /// + /// # Panics + /// + /// Panics if this endpoint is not a Unix domain socket. + #[cfg(unix)] + #[track_caller] + pub fn permissions(mut self, permissions: Permissions) -> Self { + match &mut self.l4 { + ServerAddress::Uds(_, perm) => *perm = Some(permissions), + ServerAddress::Tcp(_, _) => { + panic!("Unix domain socket permissions can only be set on UDS endpoints") + } + } + self + } + + /// Set TLS settings for this endpoint. + pub fn tls(mut self, settings: TlsSettings) -> Self { + self.tls = Some(settings); + self + } + + /// Set L4 `BufStream` buffer sizes for this endpoint. + pub fn l4_buffer(mut self, settings: L4BufferSettings) -> Self { + self.l4_buffer = settings; + self + } +} + #[derive(Clone)] pub(crate) struct TransportStack { l4: ListenerEndpoint, tls: Option>, + l4_buffer: L4BufferSettings, } impl TransportStack { @@ -168,6 +252,7 @@ impl TransportStack { Ok(UninitializedStream { l4: stream, tls: self.tls.clone(), + l4_buffer: self.l4_buffer, }) } @@ -179,11 +264,12 @@ impl TransportStack { pub(crate) struct UninitializedStream { l4: L4Stream, tls: Option>, + l4_buffer: L4BufferSettings, } impl UninitializedStream { pub async fn handshake(mut self) -> Result { - self.l4.set_buffer(); + self.l4.set_buffer(self.l4_buffer); if let Some(tls) = self.tls { let tls_stream = tls.tls_handshake(self.l4).await?; Ok(Box::new(tls_stream)) @@ -243,18 +329,22 @@ impl Listeners { /// Add a TCP endpoint to `self`. pub fn add_tcp(&mut self, addr: &str) { - self.add_address(ServerAddress::Tcp(addr.into(), None)); + self.add_listener(ListenerConfig::tcp(addr)); } /// Add a TCP endpoint to `self`, with the given [`TcpSocketOptions`]. pub fn add_tcp_with_settings(&mut self, addr: &str, sock_opt: TcpSocketOptions) { - self.add_address(ServerAddress::Tcp(addr.into(), Some(sock_opt))); + self.add_listener(ListenerConfig::tcp(addr).tcp_socket_options(sock_opt)); } /// Add a Unix domain socket endpoint to `self`. #[cfg(unix)] pub fn add_uds(&mut self, addr: &str, perm: Option) { - self.add_address(ServerAddress::Uds(addr.into(), perm)); + let endpoint = perm.map_or_else( + || ListenerConfig::uds(addr), + |perm| ListenerConfig::uds(addr).permissions(perm), + ); + self.add_listener(endpoint); } /// Add a TLS endpoint to `self` with the [Mozilla Intermediate](https://wiki.mozilla.org/Security/Server_Side_TLS#Intermediate_compatibility_.28recommended.29) @@ -272,7 +362,11 @@ impl Listeners { sock_opt: Option, settings: TlsSettings, ) { - self.add_endpoint(ServerAddress::Tcp(addr.into(), sock_opt), Some(settings)); + let mut endpoint = ListenerConfig::tcp(addr).tls(settings); + if let Some(sock_opt) = sock_opt { + endpoint = endpoint.tcp_socket_options(sock_opt); + } + self.add_listener(endpoint); } /// Add the given [`ServerAddress`] to `self`. @@ -294,11 +388,24 @@ impl Listeners { } } - /// Add the given [`ServerAddress`] to `self` with the given [`TlsSettings`] if provided + /// Add the given listener endpoint to `self`. + pub fn add_listener(&mut self, endpoint: ListenerConfig) { + let ListenerConfig { l4, tls, l4_buffer } = endpoint; + self.stacks.push(TransportStackBuilder { + l4, + tls, + l4_buffer, + #[cfg(feature = "connection_filter")] + connection_filter: self.connection_filter.clone(), + }); + } + + /// Add the given [`ServerAddress`] to `self` with the given [`TlsSettings`] if provided. pub fn add_endpoint(&mut self, l4: ServerAddress, tls: Option) { self.stacks.push(TransportStackBuilder { l4, tls, + l4_buffer: L4BufferSettings::default(), #[cfg(feature = "connection_filter")] connection_filter: self.connection_filter.clone(), }) @@ -372,6 +479,71 @@ mod test { TcpStream::connect(addrs[1]).await.unwrap(); } + #[test] + fn test_add_listener_config_tcp_l4_buffer() { + let mut listeners = Listeners::new(); + let tcp_options = TcpSocketOptions { + dscp: Some(10), + ..Default::default() + }; + let l4_buffer = L4BufferSettings { + read: Some(0), + write: None, + }; + + listeners.add_listener( + ListenerConfig::tcp("127.0.0.1:7107") + .tcp_socket_options(tcp_options) + .l4_buffer(l4_buffer), + ); + + assert_eq!(listeners.stacks.len(), 1); + assert_eq!(listeners.stacks[0].l4_buffer, l4_buffer); + assert_eq!(listeners.stacks[0].l4_buffer.read_capacity(), 0); + assert_eq!( + listeners.stacks[0].l4_buffer.write_capacity(), + DEFAULT_L4_WRITE_BUFFER_SIZE + ); + + match &listeners.stacks[0].l4 { + ServerAddress::Tcp(addr, Some(options)) => { + assert_eq!(addr, "127.0.0.1:7107"); + assert_eq!(options.dscp, Some(10)); + } + other => panic!("unexpected listener address: {other:?}"), + } + } + + #[cfg(unix)] + #[test] + fn test_add_listener_config_uds_l4_buffer() { + let mut listeners = Listeners::new(); + let l4_buffer = L4BufferSettings::unbuffered(); + + listeners.add_listener(ListenerConfig::uds("/tmp/test_builder_uds").l4_buffer(l4_buffer)); + + assert_eq!(listeners.stacks.len(), 1); + assert_eq!(listeners.stacks[0].l4_buffer, l4_buffer); + assert_eq!(listeners.stacks[0].l4_buffer.read_capacity(), 0); + assert_eq!(listeners.stacks[0].l4_buffer.write_capacity(), 0); + + match &listeners.stacks[0].l4 { + ServerAddress::Uds(addr, None) => assert_eq!(addr, "/tmp/test_builder_uds"), + other => panic!("unexpected listener address: {other:?}"), + } + } + + #[test] + fn test_l4_buffer_settings_defaults_per_direction() { + let l4_buffer = L4BufferSettings { + read: None, + write: Some(0), + }; + + assert_eq!(l4_buffer.read_capacity(), DEFAULT_L4_READ_BUFFER_SIZE); + assert_eq!(l4_buffer.write_capacity(), 0); + } + #[tokio::test] #[cfg(feature = "any_tls")] async fn test_listen_tls() { diff --git a/pingora-core/src/protocols/l4/stream.rs b/pingora-core/src/protocols/l4/stream.rs index ddbaceb1..7cbbd37c 100644 --- a/pingora-core/src/protocols/l4/stream.rs +++ b/pingora-core/src/protocols/l4/stream.rs @@ -354,14 +354,69 @@ impl AsRawSocket for RawStreamWrapper { } } -// Large read buffering helps reducing syscalls with little trade-off -// Ssl layer always does "small" reads in 16k (TLS record size) so L4 read buffer helps a lot. -const BUF_READ_SIZE: usize = 64 * 1024; -// Small write buf to match MSS. Too large write buf delays real time communication. -// This buffering effectively implements something similar to Nagle's algorithm. -// The benefit is that user space can control when to flush, where Nagle's can't be controlled. -// And userspace buffering reduce both syscalls and small packets. -const BUF_WRITE_SIZE: usize = 1460; +/// The default L4 read buffer size. +/// +/// Large read buffering helps reducing syscalls with little trade-off. The SSL +/// layer always does "small" reads in 16k chunks (TLS record size), so L4 read +/// buffering helps a lot. +pub const DEFAULT_L4_READ_BUFFER_SIZE: usize = 64 * 1024; + +/// The default L4 write buffer size. +/// +/// Small write buffering matches a typical MSS. Too large a write buffer delays +/// real-time communication. This buffering effectively implements something +/// similar to Nagle's algorithm, but user space can control when to flush. +pub const DEFAULT_L4_WRITE_BUFFER_SIZE: usize = 1460; + +/// L4 [`BufStream`] buffer sizing. +/// +/// Leaving either side as `None` preserves Pingora's default for that side. +/// Setting either side to `Some(0)` disables `BufStream` buffering for that +/// direction. +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +pub struct L4BufferSettings { + /// Read buffer size in bytes. `None` uses [`DEFAULT_L4_READ_BUFFER_SIZE`]. + pub read: Option, + /// Write buffer size in bytes. `None` uses [`DEFAULT_L4_WRITE_BUFFER_SIZE`]. + pub write: Option, +} + +impl L4BufferSettings { + /// Create settings with both read and write buffer sizes set explicitly. + pub fn new(read: usize, write: usize) -> Self { + Self { + read: Some(read), + write: Some(write), + } + } + + /// Create settings that disable both read and write `BufStream` buffering. + pub fn unbuffered() -> Self { + Self::new(0, 0) + } + + /// Set the read buffer size. + pub fn read(mut self, read: usize) -> Self { + self.read = Some(read); + self + } + + /// Set the write buffer size. + pub fn write(mut self, write: usize) -> Self { + self.write = Some(write); + self + } + + /// Resolved read buffer size after applying defaults. + pub fn read_capacity(&self) -> usize { + self.read.unwrap_or(DEFAULT_L4_READ_BUFFER_SIZE) + } + + /// Resolved write buffer size after applying defaults. + pub fn write_capacity(&self) -> usize { + self.write.unwrap_or(DEFAULT_L4_WRITE_BUFFER_SIZE) + } +} // NOTE: with writer buffering, users need to call flush() to make sure the data is actually // sent. Otherwise data could be stuck in the buffer forever or get lost when stream is closed. @@ -456,13 +511,18 @@ impl Stream { /// Set the buffer of BufStream /// It is only set later because of the malloc overhead in critical accept() path - pub(crate) fn set_buffer(&mut self) { + pub(crate) fn set_buffer(&mut self, buffer: L4BufferSettings) { use std::mem; // Since BufStream doesn't provide an API to adjust the buf directly, // we take the raw stream out of it and put it in a new BufStream with the size we want let stream = mem::take(&mut self.stream); - let stream = - stream.map(|s| BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, s.into_inner())); + let stream = stream.map(|s| { + BufStream::with_capacity( + buffer.read_capacity(), + buffer.write_capacity(), + s.into_inner(), + ) + }); let _ = mem::replace(&mut self.stream, stream); } } diff --git a/pingora-core/src/services/listening.rs b/pingora-core/src/services/listening.rs index 7b718b9b..1810ba0c 100644 --- a/pingora-core/src/services/listening.rs +++ b/pingora-core/src/services/listening.rs @@ -23,7 +23,7 @@ use crate::listeners::tls::TlsSettings; #[cfg(feature = "connection_filter")] use crate::listeners::AcceptAllFilter; use crate::listeners::{ - ConnectionFilter, Listeners, ServerAddress, TcpSocketOptions, TransportStack, + ConnectionFilter, ListenerConfig, Listeners, ServerAddress, TcpSocketOptions, TransportStack, }; use crate::protocols::Stream; #[cfg(unix)] @@ -123,6 +123,11 @@ impl Service { self.listeners.add_tcp(addr); } + /// Add a listening endpoint configured by a [`ListenerConfig`]. + pub fn add_listener(&mut self, endpoint: ListenerConfig) { + self.listeners.add_listener(endpoint); + } + /// Add a TCP listening endpoint with the given [`TcpSocketOptions`]. pub fn add_tcp_with_settings(&mut self, addr: &str, sock_opt: TcpSocketOptions) { self.listeners.add_tcp_with_settings(addr, sock_opt); From 5ff6afa73748e59c75ee774d00ecc3820de230b3 Mon Sep 17 00:00:00 2001 From: David Papp Date: Mon, 1 Jun 2026 12:33:54 +0200 Subject: [PATCH 47/52] feat(pingora-core): expose PROXY v2 extension-TLV callback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `maybe_consume_proxy_header` already parses PROXY v2 headers in `UninitializedStream::handshake` and threads the recovered source address into the SocketDigest, but it silently drops every parsed extension TLV. Consumer apps that ride application-defined metadata through the same header (HAProxy v2 spec § 2.2 reserves type IDs 0xE0..=0xEF for that) had no way to receive them. Add a global callback registration parallel to the existing `set_client_hello_callback`: pub type ProxyV2TlvCallback = Option; pub fn set_proxy_v2_tlv_callback(callback: ProxyV2TlvCallback); `maybe_consume_proxy_header` invokes the callback with the parsed `extensions` slice and the recovered source `SocketAddr` whenever the PROXY v2 header carried any TLVs. No-op when the callback isn't registered or the TLV list is empty, so existing deployments are unaffected. Re-exports `proxy_protocol::version2::ExtensionTlv` through `crate::protocols::proxy_protocol::ExtensionTlv` so callbacks can pattern-match on `ExtensionTlv::Custom { type_id, value }` without depending on the underlying proxy-protocol crate directly. Depends on the `Custom` variant added in gen0sec/proxy-protocol#12 — pingora-core's proxy-protocol dep is temporarily pinned to that branch; flip back to `main` once the PR lands. Use case: synapse-proxy's TLS-passthrough edge will encode per-flow JA4 fingerprints as a 0xE0 Custom TLV in the v2 header it already emits; the Tier-2 proxy receives them via this callback and populates its fingerprint cache without an out-of-band store. gen0sec/synapse#352. --- pingora-core/Cargo.toml | 7 +- pingora-core/src/lib.rs | 5 ++ pingora-core/src/listeners/mod.rs | 80 +++++++++++++++++++- pingora-core/src/protocols/proxy_protocol.rs | 6 ++ 4 files changed, 96 insertions(+), 2 deletions(-) diff --git a/pingora-core/Cargo.toml b/pingora-core/Cargo.toml index 5bbe983b..75e46187 100644 --- a/pingora-core/Cargo.toml +++ b/pingora-core/Cargo.toml @@ -73,7 +73,12 @@ x509-parser = { version = "0.16.0", optional = true } ouroboros = { version = "0.18.4", optional = true } lru = { workspace = true, optional = true } daggy = "0.8" -proxy-protocol = {git = "https://github.com/arxignis/proxy-protocol.git"} +# Point at the gen0sec/proxy-protocol fork's feat/custom-tlv-range +# branch which adds the `ExtensionTlv::Custom { type_id, value }` +# variant required by `set_proxy_v2_tlv_callback` below. Bump back +# to `main` (or a tagged version) once +# https://github.com/gen0sec/proxy-protocol/pull/12 lands. +proxy-protocol = { git = "https://github.com/gen0sec/proxy-protocol.git", branch = "feat/custom-tlv-range" } [target.'cfg(unix)'.dependencies] daemonize = "0.5.0" nix = "~0.24.3" diff --git a/pingora-core/src/lib.rs b/pingora-core/src/lib.rs index a4450632..b66282c9 100644 --- a/pingora-core/src/lib.rs +++ b/pingora-core/src/lib.rs @@ -105,6 +105,11 @@ pub mod utils; pub use listeners::set_client_hello_callback; pub use listeners::ClientHelloCallback; +// Re-export PROXY v2 extension-TLV callback for app-defined metadata +// (e.g. synapse's per-flow fingerprint store, see gen0sec/synapse#352). +pub use listeners::set_proxy_v2_tlv_callback; +pub use listeners::ProxyV2TlvCallback; + pub use pingora_error::{ErrorType::*, *}; // If both openssl and boringssl are enabled, prefer boringssl. diff --git a/pingora-core/src/listeners/mod.rs b/pingora-core/src/listeners/mod.rs index abc65ea1..18233070 100644 --- a/pingora-core/src/listeners/mod.rs +++ b/pingora-core/src/listeners/mod.rs @@ -129,6 +129,70 @@ fn call_client_hello_callback( } } +/// Callback function type for parsed PROXY protocol v2 extension TLVs. +/// +/// Invoked from `UninitializedStream::handshake` right after +/// `maybe_consume_proxy_header` parses a v2 header and recovers the +/// real client `SocketAddr`. Lets consumer applications (synapse, +/// custom logging) pull metadata out of application-defined TLVs +/// (HAProxy spec § 2.2 reserves type IDs `0xE0..=0xEF`) without +/// needing to parse the wire bytes themselves. +/// +/// Signature mirrors `ClientHelloCallback`: a pointer fn so the +/// registration can stay `Sync` without an `Arc` shuffle. +pub type ProxyV2TlvCallback = + Option; + +/// Global callback for parsed PROXY v2 extension TLVs. Registered by +/// `set_proxy_v2_tlv_callback`, invoked by `call_proxy_v2_tlv_callback`. +static PROXY_V2_TLV_CALLBACK: std::sync::OnceLock> = + std::sync::OnceLock::new(); + +/// Register a callback that receives the parsed PROXY v2 extension +/// TLVs from every accepted connection (in addition to the recovered +/// source address that `maybe_consume_proxy_header` already plumbs +/// into the SocketDigest). +/// +/// # Example +/// ``` +/// use pingora_core::listeners::set_proxy_v2_tlv_callback; +/// use pingora_core::protocols::proxy_protocol::ExtensionTlv; +/// use pingora_core::protocols::l4::socket::SocketAddr; +/// +/// set_proxy_v2_tlv_callback(Some(|tlvs: &[ExtensionTlv], real_addr: SocketAddr| { +/// for tlv in tlvs { +/// if let ExtensionTlv::Custom { type_id: 0xE0, value } = tlv { +/// // Application-defined TLV — decode and act on it. +/// let _ = (real_addr, value); +/// } +/// } +/// })); +/// ``` +pub fn set_proxy_v2_tlv_callback(callback: ProxyV2TlvCallback) { + PROXY_V2_TLV_CALLBACK.get_or_init(|| std::sync::Mutex::new(callback)); + if let Ok(mut cb) = PROXY_V2_TLV_CALLBACK.get().unwrap().lock() { + *cb = callback; + } +} + +/// Invoke the PROXY v2 TLV callback if registered. No-op when no +/// callback is set (the common case) or when the TLV list is empty. +fn call_proxy_v2_tlv_callback( + tlvs: &[proxy_protocol::ExtensionTlv], + real_addr: SocketAddr, +) { + if tlvs.is_empty() { + return; + } + if let Some(cb_guard) = PROXY_V2_TLV_CALLBACK.get() { + if let Ok(cb) = cb_guard.lock() { + if let Some(callback) = *cb { + callback(tlvs, real_addr); + } + } + } +} + #[cfg(unix)] use crate::server::ListenFds; @@ -367,7 +431,8 @@ impl UninitializedStream { match proxy_protocol::consume_proxy_header(&mut self.l4).await { Ok(Some(header)) => { - if let Some(real_addr) = proxy_protocol::source_addr_from_header(&header) { + let real_addr_opt = proxy_protocol::source_addr_from_header(&header); + if let Some(real_addr) = real_addr_opt { if let Some(digest) = self.l4.get_socket_digest() { let client_addr = SocketAddr::Inet(real_addr); digest.set_client_addr(client_addr.clone()); @@ -385,6 +450,19 @@ impl UninitializedStream { } else { debug!("PROXY protocol header contained no client address (LOCAL command)"); } + // Surface any v2 extension TLVs to the registered + // application callback so consumers can ride extra + // metadata (e.g. JA4 fingerprints) on the same hop + // without an out-of-band channel. We need the real + // source address as the key, so skip when source + // recovery failed. + if let Some(real_addr) = real_addr_opt { + if let proxy_protocol::ProxyHeader::Version2 { extensions, .. } = &header { + if !extensions.is_empty() { + call_proxy_v2_tlv_callback(extensions, SocketAddr::Inet(real_addr)); + } + } + } } Ok(None) => { debug!("PROXY protocol is enabled but downstream connection from {} sent no header (connection will continue)", peer_str); diff --git a/pingora-core/src/protocols/proxy_protocol.rs b/pingora-core/src/protocols/proxy_protocol.rs index a6700993..bcddd4bf 100644 --- a/pingora-core/src/protocols/proxy_protocol.rs +++ b/pingora-core/src/protocols/proxy_protocol.rs @@ -26,6 +26,12 @@ use crate::protocols::l4::stream::Stream; /// Re-export the parsed header type from the underlying `proxy_protocol` crate for convenience. pub use proxy_protocol::ProxyHeader; +/// Re-export the v2 extension TLV enum so callers of +/// `set_proxy_v2_tlv_callback` (registered in `listeners::mod`) can +/// pattern-match on `ExtensionTlv::Custom { type_id, value }` for +/// application-defined TLVs in the `0xE0..=0xEF` range. +pub use proxy_protocol::version2::ExtensionTlv; + /// Maximum number of bytes a PROXY protocol v1 header can occupy, including CRLF. pub const MAX_PROXY_V1_HEADER_LEN: usize = 108; /// Maximum number of bytes a PROXY protocol v2 header can occupy (16 bytes header + 64k body). From 810a9e3a80459886511d50610e8aabcc0622043c Mon Sep 17 00:00:00 2001 From: David Papp Date: Mon, 1 Jun 2026 12:53:55 +0200 Subject: [PATCH 48/52] chore: cargo fmt on the proxy-v2-tlv-callback additions CI's `cargo fmt --all -- --check` flagged two multi-line items in listeners/mod.rs that rustfmt's default policy collapses to a single line (the `ProxyV2TlvCallback` type alias and the `call_proxy_v2_tlv_callback` function signature). Both fit within the 100-column limit. --- pingora-core/src/listeners/mod.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pingora-core/src/listeners/mod.rs b/pingora-core/src/listeners/mod.rs index 18233070..350709e4 100644 --- a/pingora-core/src/listeners/mod.rs +++ b/pingora-core/src/listeners/mod.rs @@ -140,8 +140,7 @@ fn call_client_hello_callback( /// /// Signature mirrors `ClientHelloCallback`: a pointer fn so the /// registration can stay `Sync` without an `Arc` shuffle. -pub type ProxyV2TlvCallback = - Option; +pub type ProxyV2TlvCallback = Option; /// Global callback for parsed PROXY v2 extension TLVs. Registered by /// `set_proxy_v2_tlv_callback`, invoked by `call_proxy_v2_tlv_callback`. @@ -177,10 +176,7 @@ pub fn set_proxy_v2_tlv_callback(callback: ProxyV2TlvCallback) { /// Invoke the PROXY v2 TLV callback if registered. No-op when no /// callback is set (the common case) or when the TLV list is empty. -fn call_proxy_v2_tlv_callback( - tlvs: &[proxy_protocol::ExtensionTlv], - real_addr: SocketAddr, -) { +fn call_proxy_v2_tlv_callback(tlvs: &[proxy_protocol::ExtensionTlv], real_addr: SocketAddr) { if tlvs.is_empty() { return; } From 8162b85511908c575440de7d1b70537a7efc8d44 Mon Sep 17 00:00:00 2001 From: David Papp Date: Mon, 1 Jun 2026 12:58:51 +0200 Subject: [PATCH 49/52] chore(pingora-core): pin proxy-protocol to v0.5.3 release gen0sec/proxy-protocol#12 (the `ExtensionTlv::Custom` variant this callback depends on) shipped as v0.5.3. Drop the temporary branch reference and pin to the tagged release. --- pingora-core/Cargo.toml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pingora-core/Cargo.toml b/pingora-core/Cargo.toml index 75e46187..15da0f11 100644 --- a/pingora-core/Cargo.toml +++ b/pingora-core/Cargo.toml @@ -73,12 +73,10 @@ x509-parser = { version = "0.16.0", optional = true } ouroboros = { version = "0.18.4", optional = true } lru = { workspace = true, optional = true } daggy = "0.8" -# Point at the gen0sec/proxy-protocol fork's feat/custom-tlv-range -# branch which adds the `ExtensionTlv::Custom { type_id, value }` -# variant required by `set_proxy_v2_tlv_callback` below. Bump back -# to `main` (or a tagged version) once -# https://github.com/gen0sec/proxy-protocol/pull/12 lands. -proxy-protocol = { git = "https://github.com/gen0sec/proxy-protocol.git", branch = "feat/custom-tlv-range" } +# Pinned to the gen0sec/proxy-protocol v0.5.3 release, which added +# the `ExtensionTlv::Custom { type_id, value }` variant required by +# `set_proxy_v2_tlv_callback` below. +proxy-protocol = { git = "https://github.com/gen0sec/proxy-protocol.git", tag = "v0.5.3" } [target.'cfg(unix)'.dependencies] daemonize = "0.5.0" nix = "~0.24.3" From 5a6eef3c02108793326232d5817250afb6b7913d Mon Sep 17 00:00:00 2001 From: David Papp Date: Mon, 1 Jun 2026 13:09:29 +0200 Subject: [PATCH 50/52] =?UTF-8?q?ci:=20bump=20MSRV=20in=20the=20build=20ma?= =?UTF-8?q?trix=201.84.0=20=E2=86=92=201.85.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `proxy-protocol >= 0.5.3` declares `edition = "2024"` in its manifest, which Cargo only stabilized in Rust 1.85.0. Building pingora-core against the v0.5.3 proxy-protocol release with rustc 1.84.0 fails at manifest parse: feature `edition2024` is required The package requires the Cargo feature called `edition2024`, but that feature is not stabilized in this version of Cargo (1.84.0 (66221abde 2024-11-19)). Bumping the MSRV pin in the build matrix to 1.85.0 picks up the stabilized edition without forcing proxy-protocol to revert to edition 2021. --- .github/workflows/build.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 22a4c458..6fb59174 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,8 +7,10 @@ jobs: strategy: fail-fast: false matrix: - # nightly, msrv, and latest stable - toolchain: [nightly, 1.84.0, 1.91.1] + # nightly, msrv, and latest stable. + # MSRV bumped 1.84.0 → 1.85.0 to pick up `edition = "2024"` + # stabilization, which `proxy-protocol >= 0.5.3` requires. + toolchain: [nightly, 1.85.0, 1.91.1] runs-on: ubuntu-latest # Only run on "pull_request" event for external PRs. This is to avoid # duplicate builds for PRs created from internal branches. From a8149cf1352bb45227a47ab2194308bcdfe9ca5e Mon Sep 17 00:00:00 2001 From: David Papp Date: Mon, 1 Jun 2026 13:28:49 +0200 Subject: [PATCH 51/52] fix(pingora-http): parse authority-form URIs for CONNECT request-lines MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `test_connect_proxying_allowed_h1` and `test_connect_proxying_disallowed_h1` in `pingora-proxy/tests/test_basic.rs` were failing on the fork's main branch (no relation to this PR's listener changes) with: Fail to proxy: Downstream InvalidHTTPHeader context: invalid uri pingora.org:443 The CONNECT request-line carries an authority-form URI (`pingora.org:443`, RFC 9110 § 9.3.6) rather than origin-form (`/path?query`). `RequestHeader::set_raw_path` was building the URI via `Uri::builder().path_and_query(...)`, which only accepts origin-form and rejects authority-form with `PathDoesNotStartWithSlash`. That short-circuited the request before the CONNECT-method handler in `pingora-proxy/src/lib.rs:260-274` could return 405 (for the disallowed test) or before the proxy could tunnel the bytes (for the allowed test). Fix: - `set_raw_path` tries the permissive `path_and_query` builder first (preserves the looser byte handling existing callers depend on for paths like `\`), then falls back to `Uri::try_from(...)` which auto-detects authority / absolute / asterisk forms. - `raw_path()` no longer unwraps `path_and_query()`. Authority-form URIs have no `path_and_query`; we return the authority bytes (e.g. `pingora.org:443`). Asterisk-form falls through to an empty slice rather than panicking. Verified with `cargo test -p pingora-proxy --test test_basic test_connect_proxying` — 3/3 pass after this change. --- pingora-http/src/lib.rs | 56 ++++++++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/pingora-http/src/lib.rs b/pingora-http/src/lib.rs index 954be81b..fb9ae979 100644 --- a/pingora-http/src/lib.rs +++ b/pingora-http/src/lib.rs @@ -256,22 +256,39 @@ impl RequestHeader { /// Generally prefer [Self::set_uri()] to modify the header's URI if able. /// /// This API is to allow supporting non UTF-8 cases. + /// + /// Accepts every URI form HTTP/1.1 request-lines can carry + /// (RFC 9112 § 3.2): origin-form (`/path?query`), absolute-form + /// (`http://host/path`), authority-form (`host:port`, used by + /// CONNECT per RFC 9110 § 9.3.6), and asterisk-form (`*`, used by + /// OPTIONS). + /// + /// Tries the permissive origin-form path + /// (`Uri::builder().path_and_query(...)`) first — that's the + /// hot path for ~all traffic and accepts looser byte sequences + /// than `Uri::try_from` does (callers had been relying on that + /// for paths containing characters like `\`). When the origin + /// builder rejects the input, falls back to `Uri::try_from`, + /// which auto-detects authority / absolute / asterisk forms so + /// CONNECT request-lines parse cleanly. Both failing surfaces + /// `InvalidHTTPHeader`. pub fn set_raw_path(&mut self, path: &[u8]) -> Result<()> { + fn parse(p: &str) -> Result { + // Permissive origin-form first — preserves the + // pre-CONNECT-fix behaviour for byte-permissive paths. + if let Ok(uri) = Uri::builder().path_and_query(p).build() { + return Ok(uri); + } + // Authority-form (CONNECT) / absolute-form / asterisk-form. + Uri::try_from(p).explain_err(InvalidHTTPHeader, |_| format!("invalid uri {}", p)) + } if let Ok(p) = std::str::from_utf8(path) { - let uri = Uri::builder() - .path_and_query(p) - .build() - .explain_err(InvalidHTTPHeader, |_| format!("invalid uri {}", p))?; - self.base.uri = uri; + self.base.uri = parse(p)?; // keep raw_path empty, no need to store twice } else { // put a valid utf-8 path into base for read only access let lossy_str = String::from_utf8_lossy(path); - let uri = Uri::builder() - .path_and_query(lossy_str.as_ref()) - .build() - .explain_err(InvalidHTTPHeader, |_| format!("invalid uri {}", lossy_str))?; - self.base.uri = uri; + self.base.uri = parse(lossy_str.as_ref())?; self.raw_path_fallback = path.to_vec(); } Ok(()) @@ -294,18 +311,21 @@ impl RequestHeader { /// Return the request path in its raw format /// /// Non-UTF8 is supported. + /// + /// For authority-form request URIs (RFC 9110 § 9.3.6 / CONNECT) + /// `Uri::path_and_query()` returns `None`; we fall back to the + /// raw authority bytes (e.g. `pingora.org:443`). For asterisk-form + /// (OPTIONS *) both `path_and_query` and `authority` are absent + /// — return an empty slice rather than panic. pub fn raw_path(&self) -> &[u8] { if !self.raw_path_fallback.is_empty() { &self.raw_path_fallback + } else if let Some(pq) = self.base.uri.path_and_query() { + pq.as_str().as_bytes() + } else if let Some(auth) = self.base.uri.authority() { + auth.as_str().as_bytes() } else { - // Url should always be set - self.base - .uri - .path_and_query() - .as_ref() - .unwrap() - .as_str() - .as_bytes() + &[] } } From abf3db62edad4f31e6b87e7996e0695a551f38e1 Mon Sep 17 00:00:00 2001 From: David Papp Date: Mon, 1 Jun 2026 13:50:26 +0200 Subject: [PATCH 52/52] fix(pingora-http): tolerate non-URI raw paths; fix doctest + test lint Follow-up to the CONNECT authority-form fix. CI bailed at the test step before reaching later stages, so these were only surfaced once the CONNECT tests passed: 1. `test_single_header` / `test_multiple_header` use `b"\\"` as a request path. `http >= 1.4` rejects paths that don't start with `/` (PathDoesNotStartWithSlash) where `http <= 1.3` accepted them, so `RequestHeader::build("GET", b"\\")` started erroring. Make `set_raw_path` preserve such bytes in `raw_path_fallback` (the mechanism the non-UTF8 branch already uses) with a `/` sentinel on `base.uri`, instead of erroring. raw_path() / the H1 wire serializer read the fallback first, so the original bytes still round-trip. 2. `set_proxy_v2_tlv_callback` doctest moved the non-Copy `real_addr` inside a loop (E0382). Borrow it instead. 3. `test_single_header_no_case` used `for_each(|_| unreachable!())`, which clippy's `never_loop` flags. Replaced with `assert!(iter.next().is_none())`. Verified on the CI toolchain (1.91.1): pingora-http --lib (8 pass), pingora-core --doc (3 pass), CONNECT tests (3 pass), and `cargo +1.91.1 clippy --all-targets --all -- --deny=warnings` clean. --- pingora-core/src/listeners/mod.rs | 2 +- pingora-http/src/lib.rs | 72 +++++++++++++++++++------------ 2 files changed, 46 insertions(+), 28 deletions(-) diff --git a/pingora-core/src/listeners/mod.rs b/pingora-core/src/listeners/mod.rs index 350709e4..9f888a47 100644 --- a/pingora-core/src/listeners/mod.rs +++ b/pingora-core/src/listeners/mod.rs @@ -162,7 +162,7 @@ static PROXY_V2_TLV_CALLBACK: std::sync::OnceLock= 1.4` rejects paths that + /// don't start with `/`, where `http <= 1.3` accepted them). + /// Rather than reject the request, preserve the raw bytes in + /// `raw_path_fallback` (the same mechanism the non-UTF8 branch + /// already uses) and set a `/` sentinel on `base.uri` so the + /// structured URI stays valid. `raw_path()` and the H1 wire + /// serializer read `raw_path_fallback` first, so the original + /// bytes still round-trip verbatim. pub fn set_raw_path(&mut self, path: &[u8]) -> Result<()> { - fn parse(p: &str) -> Result { - // Permissive origin-form first — preserves the - // pre-CONNECT-fix behaviour for byte-permissive paths. - if let Ok(uri) = Uri::builder().path_and_query(p).build() { - return Ok(uri); - } - // Authority-form (CONNECT) / absolute-form / asterisk-form. - Uri::try_from(p).explain_err(InvalidHTTPHeader, |_| format!("invalid uri {}", p)) + fn parse(p: &str) -> Option { + // origin-form first, then authority/absolute/asterisk forms. + Uri::builder() + .path_and_query(p) + .build() + .ok() + .or_else(|| Uri::try_from(p).ok()) } + // Sentinel for paths that are valid UTF-8 (or lossy-decodable) + // but not structurally valid URIs; the real bytes live in + // raw_path_fallback. + const SENTINEL: &str = "/"; if let Ok(p) = std::str::from_utf8(path) { - self.base.uri = parse(p)?; - // keep raw_path empty, no need to store twice + match parse(p) { + Some(uri) => { + self.base.uri = uri; + // keep raw_path empty, no need to store twice + } + None => { + self.base.uri = Uri::from_static(SENTINEL); + self.raw_path_fallback = path.to_vec(); + } + } } else { - // put a valid utf-8 path into base for read only access + // Non-UTF8: keep a readable base URI, preserve raw bytes. let lossy_str = String::from_utf8_lossy(path); - self.base.uri = parse(lossy_str.as_ref())?; + self.base.uri = parse(lossy_str.as_ref()).unwrap_or_else(|| Uri::from_static(SENTINEL)); self.raw_path_fallback = path.to_vec(); } Ok(()) @@ -836,9 +852,10 @@ mod tests { let mut buf: Vec = vec![]; req.header_to_h1_wire(&mut buf); assert_eq!(buf, b"foo: Bar\r\n"); - req.case_header_iter().for_each(|(_, _)| { - unreachable!("request has no case"); - }); + assert!( + req.case_header_iter().next().is_none(), + "request has no case" + ); let mut resp = ResponseHeader::new_no_case(None); resp.insert_header("foo", "bar").unwrap(); @@ -846,9 +863,10 @@ mod tests { let mut buf: Vec = vec![]; resp.header_to_h1_wire(&mut buf); assert_eq!(buf, b"foo: Bar\r\n"); - resp.case_header_iter().for_each(|(_, _)| { - unreachable!("response has no case"); - }); + assert!( + resp.case_header_iter().next().is_none(), + "response has no case" + ); } #[test]