diff --git a/.bleep b/.bleep index 64f07f84..d1911283 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -f0b43320bb1a5f7788a7d0e90a804e045f0af2fb +5281b97daa3213287999fb97c2a5de57ae565011 \ No newline at end of file diff --git a/.cargo/audit.toml b/.cargo/audit.toml new file mode 100644 index 00000000..7c6e098f --- /dev/null +++ b/.cargo/audit.toml @@ -0,0 +1,3 @@ +[advisories] +# Temp before internal sync applies dependency bumps +ignore = ["RUSTSEC-2026-0097", "RUSTSEC-2026-0098", "RUSTSEC-2026-0099"] diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 22a4c458..ccdad3cc 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,8 +7,10 @@ jobs: strategy: fail-fast: false matrix: - # nightly, msrv, and latest stable - toolchain: [nightly, 1.84.0, 1.91.1] + # nightly, msrv, and latest stable. + # MSRV bumped 1.84.0 → 1.85.0 to pick up `edition = "2024"` + # stabilization, which `proxy-protocol >= 0.5.3` requires. + toolchain: [nightly, 1.85.0, 1.91.1] runs-on: ubuntu-latest # Only run on "pull_request" event for external PRs. This is to avoid # duplicate builds for PRs created from internal branches. @@ -38,12 +40,17 @@ jobs: - name: Run cargo fmt run: cargo fmt --all -- --check + - name: Run cargo check + run: cargo check --workspace + - name: Run cargo test + if: matrix.toolchain != '1.85.0' run: cargo test --verbose --lib --bins --tests --no-fail-fast # Need to run doc tests separately. 
# (https://github.com/rust-lang/cargo/issues/6669) - name: Run cargo doc test + if: matrix.toolchain != '1.85.0' run: cargo test --verbose --doc - name: Run cargo clippy diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml index b40314b3..3ae3dd57 100644 --- a/.github/workflows/semgrep.yml +++ b/.github/workflows/semgrep.yml @@ -1,24 +1,30 @@ +name: Semgrep OSS scan on: pull_request: {} + push: + branches: [main, master] workflow_dispatch: {} - push: - branches: - - main - - master schedule: - - cron: '0 0 * * *' -name: Semgrep config + - cron: '0 0 15 * *' +concurrency: + group: semgrep-${{ github.event_name }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true +permissions: + contents: read jobs: semgrep: - name: semgrep/ci - runs-on: ubuntu-latest - env: - SEMGREP_APP_TOKEN: ${{ secrets.SEMGREP_APP_TOKEN }} - SEMGREP_URL: https://cloudflare.semgrep.dev - SEMGREP_APP_URL: https://cloudflare.semgrep.dev - SEMGREP_VERSION_CHECK_URL: https://cloudflare.semgrep.dev/api/check-version - container: - image: returntocorp/semgrep + name: semgrep-oss + runs-on: ubuntu-slim steps: - - uses: actions/checkout@v4 - - run: semgrep ci + - uses: actions/checkout@v5 + with: + fetch-depth: 1 + - id: cache-semgrep + uses: actions/cache@v5 + with: + path: ~/.local + key: semgrep-1.160.0-${{ runner.os }} + - if: steps.cache-semgrep.outputs.cache-hit != 'true' + run: pip install --user semgrep==1.160.0 + - run: echo "$HOME/.local/bin" >> "$GITHUB_PATH" + - run: semgrep scan --config=auto diff --git a/.gitignore b/.gitignore index abc8cf51..a8fa88f9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ dhat-heap.json .vscode .idea .cover -bleeper.user.toml \ No newline at end of file +bleeper.user.toml +.DS_Store diff --git a/Cargo.toml b/Cargo.toml index d3c8603b..c78de1f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ members = [ "pingora-ketama", "pingora-load-balancing", "pingora-memory-cache", + "pingora-prometheus", "tinyufo", ] 
diff --git a/docs/user_guide/conf.md b/docs/user_guide/conf.md index 1f55859e..837dc519 100644 --- a/docs/user_guide/conf.md +++ b/docs/user_guide/conf.md @@ -23,12 +23,16 @@ group: webusers | threads | number of threads per service | number | | user | the user the pingora server should be run under after daemonization | string | | group | the group the pingora server should be run under after daemonization | string | +| working_directory | the working directory for the daemonized process | string | | client_bind_to_ipv4 | source IPv4 addresses to bind to when connecting to server | list of string | | client_bind_to_ipv6 | source IPv6 addresses to bind to when connecting to server| list of string | | ca_file | The path to the root CA file | string | | s2n_config_cache_size | The maximum number of unique s2n configs to cache. A value of 0 disables the cache. Default: 10 (s2n-tls only) | number | | work_stealing | Enable work stealing runtime (default true). See Pingora runtime (WIP) section for more info | bool | | upstream_keepalive_pool_size | The number of total connections to keep in the connection pool | number | +| daemon_wait_for_ready | When `true` and `daemon` is `true`, the parent process waits for the daemon to signal readiness (via `SIGUSR1`) before exiting. This causes systemd to delay sending `SIGQUIT` to the old process until the new instance is fully bootstrapped. Default: `false` | bool | +| daemon_ready_timeout_seconds | How long (in seconds) the parent waits for the daemon to signal readiness when `daemon_wait_for_ready` is `true`. If the daemon does not signal in time the parent exits with a non-zero code, causing systemd to abort the reload. Default: `600` | number | +| daemon_notify_timeout_seconds | How long (in seconds) the daemon retries sending `SIGUSR1` to the parent when the attempt fails with a permission error. This covers the brief window after the fork where the parent has not yet dropped its UID to match the daemon. 
Default: `60` | number | ## Extension Any unknown settings will be ignored. This allows extending the conf file to add and pass user defined settings. See User defined configuration section. diff --git a/docs/user_guide/modify_filter.md b/docs/user_guide/modify_filter.md index 3e5378fb..a833fc27 100644 --- a/docs/user_guide/modify_filter.md +++ b/docs/user_guide/modify_filter.md @@ -123,8 +123,7 @@ impl ProxyHttp for MyGateway { fn main() { ... - let mut prometheus_service_http = - pingora::services::listening::Service::prometheus_http_service(); + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("127.0.0.1:6192"); my_server.add_service(prometheus_service_http); diff --git a/docs/user_guide/phase.md b/docs/user_guide/phase.md index 3c80f913..5f3a891a 100644 --- a/docs/user_guide/phase.md +++ b/docs/user_guide/phase.md @@ -29,7 +29,8 @@ Pingora-proxy allows users to insert arbitrary logic into the life of a request. upstream_request_filter --> request_body_filter; request_body_filter --> SendReq{{IO: send request to upstream}}; SendReq-->RecvResp{{IO: read response from upstream}}; - RecvResp-->upstream_response_filter-->response_filter-->upstream_response_body_filter-->response_body_filter-->logging-->endreq("request done"); + RecvResp-.feature: adjust_upstream_modules.->adjust_upstream_modules; + adjust_upstream_modules-->upstream_response_filter-->response_filter-->upstream_response_body_filter-->response_body_filter-->logging-->endreq("request done"); fail_to_connect --can retry-->upstream_peer; fail_to_connect --can't retry-->fail_to_proxy--send error response-->logging; @@ -92,6 +93,11 @@ If the error is not retry-able, the request will end. ### `upstream_request_filter()` This phase is to modify requests before sending to upstream. 
+### `adjust_upstream_modules()` _(feature: `adjust_upstream_modules`)_ +This phase is triggered when the upstream response header arrives, before upstream modules (such as `upstream_compression`) process it. + +Use this to configure upstream module behavior based on the response header, e.g. setting a dictionary for dictionary-based content encoding. The response header is provided as an immutable reference; to modify the response header itself, use `upstream_response_filter()` instead. + ### `upstream_response_filter()/upstream_response_body_filter()/upstream_response_trailer_filter()` This phase is triggered after an upstream response header/body/trailer is received. diff --git a/docs/user_guide/phase_chart.md b/docs/user_guide/phase_chart.md index 94988724..b915f950 100644 --- a/docs/user_guide/phase_chart.md +++ b/docs/user_guide/phase_chart.md @@ -14,7 +14,8 @@ Pingora proxy phases without caching upstream_request_filter --> request_body_filter; request_body_filter --> SendReq{{IO: send request to upstream}}; SendReq-->RecvResp{{IO: read response from upstream}}; - RecvResp-->upstream_response_filter-->response_filter-->upstream_response_body_filter-->response_body_filter-->logging-->endreq("request done"); + RecvResp-.feature: adjust_upstream_modules.->adjust_upstream_modules; + adjust_upstream_modules-->upstream_response_filter-->response_filter-->upstream_response_body_filter-->response_body_filter-->logging-->endreq("request done"); fail_to_connect --can retry-->upstream_peer; fail_to_connect --can't retry-->fail_to_proxy--send error response-->logging; diff --git a/docs/user_guide/prom.md b/docs/user_guide/prom.md index b1868f12..1e83c0f3 100644 --- a/docs/user_guide/prom.md +++ b/docs/user_guide/prom.md @@ -1,29 +1,21 @@ # Prometheus -Pingora has a built-in prometheus HTTP metric server for scraping. +The [`pingora-prometheus`](https://docs.rs/pingora-prometheus) crate provides a +Prometheus HTTP metrics server for scraping. 
-## Enabling Prometheus Support +## Adding the Dependency -Prometheus support is an optional feature in Pingora. To use it, you need to enable the `prometheus` feature in your `Cargo.toml`: +Add `pingora-prometheus` to your `Cargo.toml`: ```toml -# If using the main pingora crate -pingora = { version = "0.8.0", features = ["prometheus"] } - -# If using pingora-core directly -pingora-core = { version = "0.8.0", features = ["prometheus"] } - -# If using pingora-proxy crate -pingora-proxy = { version = "0.8.0", features = ["prometheus"] } +pingora-prometheus = "0.8.0" ``` ## Setting up a Prometheus Metrics Endpoint -Once the feature is enabled, you can set up a Prometheus metrics endpoint like this: - ```rust ... - let mut prometheus_service_http = Service::prometheus_http_service(); + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("0.0.0.0:1234"); my_server.add_service(prometheus_service_http); my_server.run_forever(); diff --git a/pingora-cache/Cargo.toml b/pingora-cache/Cargo.toml index 401d827c..44a063ef 100644 --- a/pingora-cache/Cargo.toml +++ b/pingora-cache/Cargo.toml @@ -37,8 +37,8 @@ httpdate = "1.0.2" log = { workspace = true } async-trait = { workspace = true } parking_lot = "0.12" -cf-rustracing = "1.0" -cf-rustracing-jaeger = "1.0" +cf-rustracing = { version = "1.0", optional = true } +cf-rustracing-jaeger = { version = "1.0", optional = true } rmp = "0.8.14" tokio = { workspace = true } lru = { workspace = true } @@ -49,6 +49,8 @@ strum = { version = "0.26", features = ["derive"] } rand = "0.8" [dev-dependencies] +cf-rustracing = "1.0" +cf-rustracing-jaeger = "1.0" tokio-test = "0.4" tokio = { workspace = true, features = ["fs"] } env_logger = "0.11" @@ -73,3 +75,4 @@ openssl = ["pingora-core/openssl"] boringssl = ["pingora-core/boringssl"] rustls = ["pingora-core/rustls"] s2n = ["pingora-core/s2n"] +trace = ["dep:cf-rustracing", "dep:cf-rustracing-jaeger"] diff --git 
a/pingora-cache/src/cache_control.rs b/pingora-cache/src/cache_control.rs index 98af7fbb..f8612080 100644 --- a/pingora-cache/src/cache_control.rs +++ b/pingora-cache/src/cache_control.rs @@ -92,6 +92,51 @@ impl DirectiveValue { } } } + + /// Parse the [DirectiveValue] as delta seconds, permitting a fractional component. + /// + /// Values with a fractional part are rounded down (floored) to the nearest + /// non-negative integer: e.g. `1.9` -> `1`, `0.5` -> `0`. This is useful + /// for compatibility with upstreams that emit fractional ttls (which + /// RFC 9111 strict integer parsing rejects). + /// + /// Integer parsing is attempted first, so strictly-numeric values (including + /// overflow-capped values and quoted integers) behave identically to + /// [Self::parse_as_delta_seconds]. Negative, non-finite, or non-numeric + /// values still return an error. + /// + /// `"`s are ignored. The value is capped to [DELTA_SECONDS_OVERFLOW_VALUE]. + pub fn parse_as_delta_seconds_floor(&self) -> Result { + // UTF-8 validate once; on non-UTF8 input, propagate the same error as + // [Self::parse_as_delta_seconds]. + let s = self.parse_as_str()?; + match s.parse::() { + Ok(value) => Ok(value), + Err(e) if e.kind() == &IntErrorKind::PosOverflow => Ok(DELTA_SECONDS_OVERFLOW_VALUE), + Err(int_err) => { + // Fall back to parsing as a non-negative finite float and floor. + // On any failure, return an error equivalent to the strict + // [Self::parse_as_delta_seconds] u32-parse error. + match s.parse::() { + Ok(f) if f.is_finite() && f >= 0.0 => { + if f >= DELTA_SECONDS_OVERFLOW_VALUE as f64 { + Ok(DELTA_SECONDS_OVERFLOW_VALUE) + } else { + // Safe cast: `f` is finite, non-negative, and strictly + // less than `DELTA_SECONDS_OVERFLOW_VALUE` (i32::MAX), + // which fits in `u32` after flooring. 
+ Ok(f.floor() as u32) + } + } + _ => Error::e_because( + ErrorType::InternalError, + "could not parse value as u32", + int_err, + ), + } + } + } + } } /// An ordered map to store cache control key value pairs. @@ -102,6 +147,15 @@ pub type DirectiveMap = IndexMap>; pub struct CacheControl { /// The parsed directives pub directives: DirectiveMap, + /// When set, delta-seconds directives (`max-age`, `s-maxage`, + /// `stale-while-revalidate`, `stale-if-error`) accept fractional values + /// and round them down to the nearest non-negative integer. + /// + /// Defaults to `false`, matching RFC 9111 strict integer parsing. Enable + /// via [CacheControl::with_float_seconds] (or by assigning the field + /// directly) for contexts that need to interoperate with upstreams that + /// emit fractional ttls. + pub allow_float_seconds: bool, } /// Cacheability calculated from cache control. @@ -198,7 +252,20 @@ impl CacheControl { directives.insert(key.unwrap(), value); } } - Some(CacheControl { directives }) + Some(CacheControl { + directives, + allow_float_seconds: false, + }) + } + + /// Builder setter: enable fractional delta-seconds parsing. + /// + /// See [CacheControl::allow_float_seconds] for semantics. Returns `self` + /// so it can be chained onto a parser call, e.g. + /// `CacheControl::from_resp_headers(&resp).map(|cc| cc.with_float_seconds())`. + pub fn with_float_seconds(mut self) -> Self { + self.allow_float_seconds = true; + self } /// Parse from the given header name in `headers` @@ -282,7 +349,12 @@ impl CacheControl { fn parse_delta_seconds(&self, key: &str) -> Result> { if let Some(Some(dir_value)) = self.directives.get(key) { - Ok(Some(dir_value.parse_as_delta_seconds()?)) + let value = if self.allow_float_seconds { + dir_value.parse_as_delta_seconds_floor()? + } else { + dir_value.parse_as_delta_seconds()? 
+ }; + Ok(Some(value)) } else { Ok(None) } @@ -869,4 +941,168 @@ mod tests { let cc = CacheControl::from_req_headers(&req).unwrap(); assert!(cc.only_if_cached()) } + + #[test] + fn test_parse_as_delta_seconds_floor() { + // Integer values behave identically to the strict parser + let v = DirectiveValue(b"10".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 10); + + let v = DirectiveValue(b"\"10\"".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 10); + + // Quoted fractional values are unwrapped by parse_as_str and floored. + let v = DirectiveValue(b"\"1.5\"".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 1); + + let v = DirectiveValue(b"0".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 0); + + // Integer positive overflow is still capped + let v = DirectiveValue(b"99999999999999999999".to_vec()); + assert_eq!( + v.parse_as_delta_seconds_floor().unwrap(), + DELTA_SECONDS_OVERFLOW_VALUE + ); + + // Floats are floored + let v = DirectiveValue(b"1.5".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 1); + + let v = DirectiveValue(b"1.9".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 1); + + let v = DirectiveValue(b"0.5".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 0); + + let v = DirectiveValue(b"3600.0".to_vec()); + assert_eq!(v.parse_as_delta_seconds_floor().unwrap(), 3600); + + // Float positive overflow is capped + let v = DirectiveValue(b"99999999999.5".to_vec()); + assert_eq!( + v.parse_as_delta_seconds_floor().unwrap(), + DELTA_SECONDS_OVERFLOW_VALUE + ); + + // Negative values are rejected (matches strict behavior) + assert!(DirectiveValue(b"-1".to_vec()) + .parse_as_delta_seconds_floor() + .is_err()); + assert!(DirectiveValue(b"-1.5".to_vec()) + .parse_as_delta_seconds_floor() + .is_err()); + + // Non-finite / non-numeric values are rejected + assert!(DirectiveValue(b"NaN".to_vec()) + .parse_as_delta_seconds_floor() + 
.is_err()); + assert!(DirectiveValue(b"inf".to_vec()) + .parse_as_delta_seconds_floor() + .is_err()); + assert!(DirectiveValue(b"abc".to_vec()) + .parse_as_delta_seconds_floor() + .is_err()); + + // Non-UTF8 bytes are rejected with the same utf-8 error as the strict parser. + let v = DirectiveValue(b"ba\xFFr".to_vec()); + assert_eq!( + v.parse_as_delta_seconds_floor() + .unwrap_err() + .context + .unwrap() + .to_string(), + "could not parse value as utf8", + ); + } + + #[test] + fn test_cache_control_allow_float_seconds_non_utf8_value() { + // Non-UTF8 bytes inside `max-age` should still produce the utf-8 error + // when the float-permitting flag is on, matching the strict parser. + let mut resp = response::Builder::new().body(()).unwrap(); + resp.headers_mut().insert( + CACHE_CONTROL, + HeaderValue::from_bytes(b"max-age=ba\xFFr").unwrap(), + ); + let (parts, _) = resp.into_parts(); + let cc = CacheControl::from_resp_headers(&parts) + .unwrap() + .with_float_seconds(); + assert_eq!( + cc.max_age().unwrap_err().context.unwrap().to_string(), + "could not parse value as utf8", + ); + } + + #[test] + fn test_cache_control_allow_float_seconds_default_off() { + // Default (strict) parsing: fractional values produce an error, and + // [InterpretCacheControl::fresh_duration] returns None, matching the + // pre-existing behavior. + let resp = build_response(CACHE_CONTROL, "max-age=10.7"); + let cc = CacheControl::from_resp_headers(&resp).unwrap(); + assert!(!cc.allow_float_seconds); + assert!(cc.max_age().is_err()); + assert!(cc.fresh_duration().is_none()); + } + + #[test] + fn test_cache_control_with_float_seconds() { + // `max-age` with a fractional value is floored when the flag is on. 
+ let resp = build_response(CACHE_CONTROL, "max-age=10.7"); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert!(cc.allow_float_seconds); + assert_eq!(cc.max_age().unwrap().unwrap(), 10); + assert_eq!(cc.fresh_duration().unwrap(), Duration::from_secs(10)); + + // `s-maxage` still wins over `max-age` and is also floored. + let resp = build_response(CACHE_CONTROL, "s-maxage=3600.99, max-age=1800"); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert_eq!(cc.s_maxage().unwrap().unwrap(), 3600); + assert_eq!(cc.fresh_duration().unwrap(), Duration::from_secs(3600)); + + // `stale-while-revalidate` and `stale-if-error` also pick up flooring. + let resp = build_response( + CACHE_CONTROL, + "max-age=10, stale-while-revalidate=60.5, stale-if-error=30.9", + ); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert_eq!(cc.stale_while_revalidate().unwrap().unwrap(), 60); + assert_eq!(cc.stale_if_error().unwrap().unwrap(), 30); + assert_eq!( + cc.serve_stale_while_revalidate_duration().unwrap(), + Duration::from_secs(60) + ); + assert_eq!( + cc.serve_stale_if_error_duration().unwrap(), + Duration::from_secs(30) + ); + + // Integer values are unaffected when the flag is on. + let resp = build_response(CACHE_CONTROL, "max-age=12345"); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert_eq!(cc.fresh_duration().unwrap(), Duration::from_secs(12345)); + + // Invalid (non-numeric, negative) values still fail to parse under the flag. 
+ let resp = build_response(CACHE_CONTROL, "max-age=abc"); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert!(cc.max_age().is_err()); + + let resp = build_response(CACHE_CONTROL, "max-age=-1.5"); + let cc = CacheControl::from_resp_headers(&resp) + .unwrap() + .with_float_seconds(); + assert!(cc.max_age().is_err()); + } } diff --git a/pingora-cache/src/eviction/lru.rs b/pingora-cache/src/eviction/lru.rs index d241ee69..f7a03f99 100644 --- a/pingora-cache/src/eviction/lru.rs +++ b/pingora-cache/src/eviction/lru.rs @@ -72,7 +72,9 @@ impl Manager { self.0.shard_weight(shard) } - /// Get the number of items in a specific shard + /// Get the number of items in a specific shard. Best-effort + /// lock-free read; see [`pingora_lru::Lru::shard_len`] for the + /// consistency semantics. pub fn shard_len(&self, shard: usize) -> usize { self.0.shard_len(shard) } @@ -85,6 +87,15 @@ impl Manager { (u64key(key) % N as u64) as usize } + /// Peek at the least-recently-used key in the given shard without evicting it. + /// + /// Returns the cache key at the LRU tail of the shard, or `None` if empty. + /// Useful for reporting the eviction frontier (the age of the next item + /// that would be evicted). 
+ pub fn peek_lru(&self, shard: usize) -> Option { + self.0.peek_lru(shard).map(|(key, _weight)| key) + } + /// Serialize the given shard pub fn serialize_shard(&self, shard: usize) -> Result> { use rmp_serde::encode::Serializer; @@ -614,4 +625,42 @@ mod test { // Cleanup test directory std::fs::remove_dir_all(dir_path).unwrap(); } + + #[test] + fn test_peek_lru() { + let lru = Manager::<1>::with_capacity(20, 20); + let until = SystemTime::now(); + + // empty shard returns None + assert!(lru.peek_lru(0).is_none()); + + let key1 = CacheKey::new("", "a", "1").to_compact(); + lru.admit(key1.clone(), 1, until); + // single item: it's both the head and the tail + assert_eq!(lru.peek_lru(0).unwrap(), key1); + + // admit more keys to push key1 to the tail + let key2 = CacheKey::new("", "b", "1").to_compact(); + lru.admit(key2.clone(), 1, until); + for i in 0..5 { + lru.admit( + CacheKey::new("", format!("f{i}"), "1").to_compact(), + 1, + until, + ); + } + // key1 is the LRU tail (admitted first) + assert_eq!(lru.peek_lru(0).unwrap(), key1); + + // promote key1 — now key2 becomes the tail + lru.access(&key1, 1, until); + assert_eq!(lru.peek_lru(0).unwrap(), key2); + + // peek_lru should not remove the item + assert_eq!(lru.peek_lru(0).unwrap(), key2); + assert!(lru.peek(&key2)); + + // out-of-bounds shard returns None + assert!(lru.peek_lru(999).is_none()); + } } diff --git a/pingora-cache/src/lib.rs b/pingora-cache/src/lib.rs index 867cff08..6d13409b 100644 --- a/pingora-cache/src/lib.rs +++ b/pingora-cache/src/lib.rs @@ -16,7 +16,6 @@ #![allow(clippy::new_without_default)] -use cf_rustracing::tag::Tag; use http::{method::Method, request::Parts as ReqHeader, response::Parts as RespHeader}; use key::{CacheHashKey, CompactCacheKey, HashBinary}; use lock::WritePermit; @@ -27,7 +26,7 @@ use pingora_timeout::timeout; use std::time::{Duration, Instant, SystemTime}; use storage::MissFinishType; use strum::IntoStaticStr; -use trace::{CacheTraceCTX, Span}; +use 
trace::{CacheTraceCTX, Span, Tag}; pub mod cache_control; pub mod eviction; @@ -435,6 +434,7 @@ impl HttpCache { self.phase = CachePhase::Disabled(reason); self.release_write_lock(reason); // enabled_ctx will be cleared out + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] let mut inner_enabled = self .inner_mut() .enabled_ctx @@ -1049,6 +1049,7 @@ impl HttpCache { inner_enabled.meta.replace(meta); + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] let mut span = inner_enabled.traces.child("update_meta"); let result = inner_enabled .storage @@ -1291,6 +1292,7 @@ impl HttpCache { .enabled_ctx .as_mut() .expect("Cache enabled on cache_lookup"); + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] let mut span = inner_enabled.traces.child("lookup"); let key = inner.key.as_ref().unwrap(); // safe, this phase should have cache key let now = Instant::now(); @@ -1446,6 +1448,7 @@ impl HttpCache { /// Check [Self::is_cache_locked()], panic if this request doesn't have a read lock. pub async fn cache_lock_wait(&mut self) -> LockStatus { let inner_enabled = self.inner_enabled_mut(); + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] let mut span = inner_enabled.traces.child("cache_lock"); // should always call is_cache_locked() before this function, which should guarantee that // the inner cache has a read lock and lock ctx @@ -1535,6 +1538,7 @@ impl HttpCache { }) } + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] async fn purge_impl( storage: &'static (dyn storage::Storage + Sync), eviction: Option<&'static (dyn eviction::EvictionManager + Sync)>, diff --git a/pingora-cache/src/lock.rs b/pingora-cache/src/lock.rs index 5633b09c..102b3380 100644 --- a/pingora-cache/src/lock.rs +++ b/pingora-cache/src/lock.rs @@ -14,8 +14,8 @@ //! 
Cache lock +use crate::trace::{Span, Tag}; use crate::{hashtable::ConcurrentHashTable, key::CacheHashKey, CacheKey}; -use crate::{Span, Tag}; use http::Extensions; use pingora_timeout::timeout; diff --git a/pingora-cache/src/memory.rs b/pingora-cache/src/memory.rs index 6ab57c80..e6e12acf 100644 --- a/pingora-cache/src/memory.rs +++ b/pingora-cache/src/memory.rs @@ -426,7 +426,7 @@ impl Storage for MemCache { #[cfg(test)] mod test { use super::*; - use cf_rustracing::span::Span; + use crate::trace::Span; use once_cell::sync::Lazy; fn gen_meta() -> CacheMeta { diff --git a/pingora-cache/src/put.rs b/pingora-cache/src/put.rs index fbbbb70e..94265055 100644 --- a/pingora-cache/src/put.rs +++ b/pingora-cache/src/put.rs @@ -84,6 +84,7 @@ impl CachePutCtx { } async fn put_header(&mut self, meta: CacheMeta) -> Result<()> { + #[cfg_attr(not(feature = "trace"), allow(unused_mut))] let mut trace = self.trace.child("cache put header", |o| o.start()); let miss_handler = self .storage @@ -239,7 +240,7 @@ impl CachePutCtx { #[cfg(test)] mod test { use super::*; - use cf_rustracing::span::Span; + use crate::trace::Span; use once_cell::sync::Lazy; struct TestCachePut(); diff --git a/pingora-cache/src/trace.rs b/pingora-cache/src/trace.rs index f27929a2..e8ab85cf 100644 --- a/pingora-cache/src/trace.rs +++ b/pingora-cache/src/trace.rs @@ -13,26 +13,129 @@ // limitations under the License. //! Distributed tracing helpers +//! +//! When the `trace` feature is enabled, this module re-exports the real +//! [`cf_rustracing`]/[`cf_rustracing_jaeger`] span types. +//! +//! When the `trace` feature is **disabled**, lightweight no-op shim types are +//! provided instead so that the rest of the crate compiles without pulling in +//! the tracing dependencies. 
-use cf_rustracing_jaeger::span::SpanContextState; use std::time::SystemTime; use crate::{CacheMeta, CachePhase, HitStatus}; -pub use cf_rustracing::tag::Tag; +// --------------------------------------------------------------------------- +// Real tracing implementation (feature = "trace") +// --------------------------------------------------------------------------- +#[cfg(feature = "trace")] +mod real { + pub use cf_rustracing::tag::Tag; -pub type Span = cf_rustracing::span::Span; -pub type SpanHandle = cf_rustracing::span::SpanHandle; + use cf_rustracing_jaeger::span::SpanContextState; -#[derive(Debug)] -pub(crate) struct CacheTraceCTX { - // parent span - pub cache_span: Span, - // only spans across multiple calls need to store here - pub miss_span: Span, - pub hit_span: Span, + pub type Span = cf_rustracing::span::Span; + pub type SpanHandle = cf_rustracing::span::SpanHandle; } +#[cfg(feature = "trace")] +pub use real::*; + +// --------------------------------------------------------------------------- +// No-op shim types (feature = "trace" disabled) +// --------------------------------------------------------------------------- +#[cfg(not(feature = "trace"))] +mod noop { + /// A no-op replacement for [`cf_rustracing::tag::Tag`]. + #[derive(Debug)] + pub struct Tag { + _priv: (), + } + + impl Tag { + /// Create a no-op tag. All arguments are ignored. + #[inline] + pub fn new(_name: N, _value: V) -> Self { + Tag { _priv: () } + } + } + + /// A no-op replacement for a rustracing `Span`. + #[derive(Debug)] + pub struct Span { + _priv: (), + } + + impl Span { + /// Return an inactive (no-op) span. + #[inline] + pub fn inactive() -> Self { + Span { _priv: () } + } + + /// Return a no-op handle. + #[inline] + pub fn handle(&self) -> SpanHandle { + SpanHandle { _priv: () } + } + + /// No-op: create a child span. 
+ #[inline] + pub fn child(&self, _name: &'static str, _f: F) -> Span + where + F: FnOnce(SpanOptionsPlaceholder) -> SpanOptionsPlaceholder, + { + Span::inactive() + } + + /// No-op: set a single tag via a closure. + #[inline] + pub fn set_tag Tag>(&self, _f: F) {} + + /// No-op: set multiple tags via a closure. + #[inline] + pub fn set_tags(&self, _f: F) + where + F: FnOnce() -> I, + I: IntoIterator, + { + } + + /// No-op: set a finish time. + #[inline] + pub fn set_finish_time std::time::SystemTime>(&self, _f: F) {} + } + + /// Placeholder type used in [`Span::child`] closure signatures so that + /// existing call-sites like `span.child("name", |o| o.start())` compile. + #[doc(hidden)] + pub struct SpanOptionsPlaceholder { + _priv: (), + } + + impl SpanOptionsPlaceholder { + /// No-op: mirrors `SpanOptions::start()`. + #[inline] + pub fn start(self) -> Self { + self + } + } + + /// A no-op replacement for a rustracing `SpanHandle`. + #[derive(Debug)] + pub struct SpanHandle { + _priv: (), + } +} + +#[cfg(not(feature = "trace"))] +pub use noop::*; + +// --------------------------------------------------------------------------- +// Shared helpers (work with both real and no-op types) +// --------------------------------------------------------------------------- + +/// Tag a span with metadata from a [`CacheMeta`]. 
pub fn tag_span_with_meta(span: &mut Span, meta: &CacheMeta) { fn ts2epoch(ts: SystemTime) -> f64 { ts.duration_since(SystemTime::UNIX_EPOCH) @@ -55,6 +158,15 @@ pub fn tag_span_with_meta(span: &mut Span, meta: &CacheMeta) { }); } +#[derive(Debug)] +pub(crate) struct CacheTraceCTX { + // parent span + pub cache_span: Span, + // only spans across multiple calls need to store here + pub miss_span: Span, + pub hit_span: Span, +} + impl CacheTraceCTX { pub fn new() -> Self { CacheTraceCTX { diff --git a/pingora-core/Cargo.toml b/pingora-core/Cargo.toml index 5bbe983b..bcad7acb 100644 --- a/pingora-core/Cargo.toml +++ b/pingora-core/Cargo.toml @@ -47,7 +47,6 @@ strum = "0.26.2" strum_macros = "0.26.2" libc = "0.2.70" chrono = { version = "~0.4.31", features = ["alloc"], default-features = false } -prometheus = { version = "0.14", optional = true } sentry = { version = "0.36", features = [ "backtrace", "contexts", @@ -73,9 +72,13 @@ x509-parser = { version = "0.16.0", optional = true } ouroboros = { version = "0.18.4", optional = true } lru = { workspace = true, optional = true } daggy = "0.8" -proxy-protocol = {git = "https://github.com/arxignis/proxy-protocol.git"} +# Pinned to the gen0sec/proxy-protocol v0.5.3 release, which added +# the `ExtensionTlv::Custom { type_id, value }` variant required by +# `set_proxy_v2_tlv_callback` below. 
+proxy-protocol = { git = "https://github.com/gen0sec/proxy-protocol.git", tag = "v0.5.3" } [target.'cfg(unix)'.dependencies] daemonize = "0.5.0" +flurry = "0.5" nix = "~0.24.3" [target.'cfg(windows)'.dependencies] @@ -85,15 +88,18 @@ windows-sys = { version = "0.59.0", features = ["Win32_Networking_WinSock"] } h2 = { workspace = true, features = ["unstable"] } tokio-stream = { version = "0.1", features = ["full"] } env_logger = "0.11" -reqwest = { version = "0.11", features = [ +reqwest = { version = "0.12", features = [ "rustls-tls", + "http2", ], default-features = false } -hyper = "0.14" +hyper = { version = "1", features = ["client", "http1", "http2"] } +hyper-util = { version = "0.1", features = ["client-legacy", "http1", "http2"] } +http-body-util = "0.1" rstest = "0.23.0" rustls = "0.23" [target.'cfg(unix)'.dev-dependencies] -hyperlocal = "0.8" +hyperlocal = "0.9" jemallocator = "0.5" [features] @@ -107,4 +113,3 @@ openssl_derived = ["any_tls"] any_tls = [] sentry = ["dep:sentry"] connection_filter = [] -prometheus = ["dep:prometheus"] diff --git a/pingora-core/src/apps/mod.rs b/pingora-core/src/apps/mod.rs index 8c087489..93bea8b4 100644 --- a/pingora-core/src/apps/mod.rs +++ b/pingora-core/src/apps/mod.rs @@ -15,14 +15,11 @@ //! The abstraction and implementation interface for service application logic pub mod http_app; -#[cfg(feature = "prometheus")] -pub mod prometheus_http_app; use crate::server::ShutdownWatch; use async_trait::async_trait; use log::{debug, error}; use std::any::Any; -use std::future::poll_fn; use std::sync::Arc; use crate::protocols::http::v2::server; @@ -60,7 +57,7 @@ pub trait ServerApp { async fn cleanup(&self) {} } #[non_exhaustive] -#[derive(Default)] +#[derive(Clone, Debug, Default)] /// HTTP Server options that control how the server handles some transport types. pub struct HttpServerOptions { /// Allow HTTP/2 for plaintext. 
@@ -252,8 +249,7 @@ where }); let h2_options = self.h2_options(); - let h2_conn = server::handshake(stream, h2_options).await; - let mut h2_conn = match h2_conn { + let h2_conn = match server::handshake(stream, h2_options).await { Err(e) => { error!("H2 handshake error {e}"); return None; @@ -261,36 +257,21 @@ where Ok(c) => c, }; - let mut shutdown = shutdown.clone(); - loop { - // this loop ends when the client decides to close the h2 conn - // TODO: add a timeout? - let h2_stream = tokio::select! { - _ = shutdown.changed() => { - h2_conn.graceful_shutdown(); - let _ = poll_fn(|cx| h2_conn.poll_closed(cx)) - .await.map_err(|e| error!("H2 error waiting for shutdown {e}")); - return None; - } - h2_stream = server::HttpSession::from_h2_conn(&mut h2_conn, digest.clone()) => h2_stream - }; - let h2_stream = match h2_stream { - Err(e) => { - // It is common for the client to just disconnect TCP without properly - // closing H2. So we don't log the errors here - debug!("H2 error when accepting new stream {e}"); - return None; - } - Ok(s) => s?, // None means the connection is ready to be closed - }; - let app = self.clone(); - let shutdown = shutdown.clone(); + // The accept-loop body — including the graceful-shutdown state + // machine — lives in `server::accept_downstream_sessions` so that + // the same code path is exercised by tests in `protocols::http::v2`. 
+ let app = self.clone(); + let shutdown_for_session = shutdown.clone(); + server::accept_downstream_sessions(h2_conn, digest, shutdown.clone(), |h2_stream| { + let app = app.clone(); + let shutdown = shutdown_for_session.clone(); pingora_runtime::current_handle().spawn(async move { // Note, `PersistentSettings` not currently relevant for h2 app.process_new_http(ServerSession::new_http2(h2_stream), &shutdown) .await; }); - } + }) + .await; } else if custom || matches!(stream.selected_alpn_proto(), Some(ALPN::Custom(_))) { return self.clone().process_custom_session(stream, shutdown).await; } else { diff --git a/pingora-core/src/apps/prometheus_http_app.rs b/pingora-core/src/apps/prometheus_http_app.rs deleted file mode 100644 index f06cce7d..00000000 --- a/pingora-core/src/apps/prometheus_http_app.rs +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2026 Cloudflare, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! An HTTP application that reports Prometheus metrics. - -#[cfg(feature = "prometheus")] -mod prometheus_impl { - use async_trait::async_trait; - use http::Response; - use prometheus::{Encoder, TextEncoder}; - - use super::super::http_app::HttpServer; - use crate::apps::http_app::ServeHttp; - use crate::modules::http::compression::ResponseCompressionBuilder; - use crate::protocols::http::ServerSession; - - /// An HTTP application that reports Prometheus metrics. 
- /// - /// This application will report all the [static metrics](https://docs.rs/prometheus/latest/prometheus/index.html#static-metrics) - /// collected via the [Prometheus](https://docs.rs/prometheus/) crate; - pub struct PrometheusHttpApp; - - #[async_trait] - impl ServeHttp for PrometheusHttpApp { - async fn response(&self, _http_session: &mut ServerSession) -> Response> { - let encoder = TextEncoder::new(); - let metric_families = prometheus::gather(); - let mut buffer = vec![]; - encoder.encode(&metric_families, &mut buffer).unwrap(); - Response::builder() - .status(200) - .header(http::header::CONTENT_TYPE, encoder.format_type()) - .header(http::header::CONTENT_LENGTH, buffer.len()) - .body(buffer) - .unwrap() - } - } - - /// The [HttpServer] for [PrometheusHttpApp] - /// - /// This type provides the functionality of [PrometheusHttpApp] with compression enabled - pub type PrometheusServer = HttpServer; - - impl PrometheusServer { - pub fn new() -> Self { - let mut server = Self::new_app(PrometheusHttpApp); - // enable gzip level 7 compression - server.add_module(ResponseCompressionBuilder::enable(7)); - server - } - } -} - -#[cfg(feature = "prometheus")] -pub use prometheus_impl::*; diff --git a/pingora-core/src/connectors/http/mod.rs b/pingora-core/src/connectors/http/mod.rs index 2545cf7c..5a671ef2 100644 --- a/pingora-core/src/connectors/http/mod.rs +++ b/pingora-core/src/connectors/http/mod.rs @@ -21,6 +21,8 @@ use crate::protocols::http::client::HttpSession; use crate::protocols::http::v1::client::HttpSession as Http1Session; use crate::upstreams::peer::Peer; use pingora_error::Result; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; use std::time::Duration; pub mod custom; @@ -151,6 +153,17 @@ where pub fn prefer_h1(&self, peer: &impl Peer) { self.h2.prefer_h1(peer); } + + /// Return the number of times a pooled connection was found to contain + /// unexpected data from the server. 
+ pub fn unexpected_data_connection_count(&self) -> u64 { + self.h1.unexpected_data_connection_count() + } + + /// Return a shared reference to the unexpected data connection counter for periodic metric reporting. + pub fn unexpected_data_connection_counter(&self) -> Arc { + self.h1.unexpected_data_connection_counter() + } } #[cfg(test)] diff --git a/pingora-core/src/connectors/http/v1.rs b/pingora-core/src/connectors/http/v1.rs index 62ecfcb6..ab04b2f6 100644 --- a/pingora-core/src/connectors/http/v1.rs +++ b/pingora-core/src/connectors/http/v1.rs @@ -17,6 +17,8 @@ use crate::protocols::http::v1::client::HttpSession; use crate::upstreams::peer::Peer; use pingora_error::Result; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; use std::time::Duration; pub struct Connector { @@ -60,6 +62,17 @@ impl Connector { .release_stream(stream, peer.reuse_hash(), idle_timeout); } } + + /// Return the number of times a pooled connection was found to contain + /// unexpected data from the server. + pub fn unexpected_data_connection_count(&self) -> u64 { + self.transport.unexpected_data_connection_count() + } + + /// Return a shared reference to the unexpected data connection counter for periodic metric reporting. 
+ pub fn unexpected_data_connection_counter(&self) -> Arc { + self.transport.unexpected_data_connection_counter() + } } #[cfg(test)] diff --git a/pingora-core/src/connectors/http/v2.rs b/pingora-core/src/connectors/http/v2.rs index c5ec42db..3cde4b89 100644 --- a/pingora-core/src/connectors/http/v2.rs +++ b/pingora-core/src/connectors/http/v2.rs @@ -24,9 +24,10 @@ use bytes::Bytes; use h2::client::SendRequest; use log::debug; use parking_lot::{Mutex, RwLock}; -use pingora_error::{Error, ErrorType::*, OrErr, Result}; +use pingora_error::{Error, ErrorType::*, OkOrErr, OrErr, Result}; use pingora_pool::{ConnectionMeta, ConnectionPool, PoolNode}; use std::collections::HashMap; +use std::io::ErrorKind; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; @@ -152,15 +153,23 @@ impl ConnectionRef { Err(e) => { // fail to create the stream, reset the counter self.0.current_streams.fetch_sub(1, Ordering::SeqCst); - // Remote sends GOAWAY(NO_ERROR): graceful shutdown: this connection no longer - // accepts new streams. We can still try to create new connection. 
- if e.root_cause() + + // Check for graceful shutdown conditions where we can retry with a new connection + let is_graceful_shutdown = e + .root_cause() .downcast_ref::() .map(|e| { - e.is_go_away() && e.is_remote() && e.reason() == Some(h2::Reason::NO_ERROR) + // Remote sends GOAWAY(NO_ERROR): graceful shutdown + (e.is_go_away() && e.is_remote() && e.reason() == Some(h2::Reason::NO_ERROR)) + // Or broken pipe wrapped inside an h2::Error: stream closed unexpectedly + || (e.is_io() + && e.get_io() + .map(|io| io.kind() == ErrorKind::BrokenPipe) + .unwrap_or(false)) }) - .unwrap_or(false) - { + .unwrap_or(false); + + if is_graceful_shutdown { self.mark_shutdown(); Ok(None) } else { @@ -334,12 +343,17 @@ impl Connector { // the caller that the server speaks h2c } } - let max_h2_stream = peer.get_peer_options().map_or(1, |o| o.max_h2_streams); - let conn = handshake(stream, max_h2_stream, peer.h2_ping_interval()).await?; - let h2_stream = conn - .spawn_stream() - .await? - .expect("newly created connections should have at least one free stream"); + let peer_options = peer.get_peer_options(); + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = peer_options.map_or(1, |o| o.max_h2_streams); + settings.ping_interval = peer.h2_ping_interval(); + settings.stream_window_size = peer_options.and_then(|o| o.h2_stream_window_size); + settings.connection_window_size = peer_options.and_then(|o| o.h2_connection_window_size); + let conn = handshake(stream, settings).await?; + let h2_stream = conn.spawn_stream().await?.or_err( + H2Error, + "newly created connection has no free streams (server may have sent GOAWAY)", + )?; if conn.more_streams_allowed() { self.in_use_pool.insert(peer.reuse_hash(), conn); } @@ -475,19 +489,84 @@ impl Connector { // 8 Mbytes = 80 Mbytes X 100ms, which should be enough for most links. 
const H2_WINDOW_SIZE: u32 = 1 << 23; -pub async fn handshake( - stream: Stream, - max_streams: usize, - h2_ping_interval: Option, -) -> Result { +/// Maximum allowed H2 window size per [RFC 9113 §6.9.1](https://datatracker.ietf.org/doc/html/rfc9113#section-6.9.1-7). +const H2_MAX_WINDOW_SIZE: u32 = (1u32 << 31) - 1; + +/// Settings for HTTP/2 handshake. +/// +/// # Example +/// +/// ```rust,ignore +/// use pingora_core::connectors::http::v2::{handshake, H2HandshakeSettings}; +/// +/// // With custom window sizes +/// let mut settings = H2HandshakeSettings::new(); +/// settings.max_streams = 100; +/// settings.stream_window_size = Some(1 << 20); // 1MiB +/// settings.connection_window_size = Some(1 << 24); // 16MiB +/// let conn = handshake(stream, settings).await?; +/// ``` +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct H2HandshakeSettings { + /// The maximum number of concurrent streams allowed on this connection. + pub max_streams: usize, + /// Optional interval for sending H2 ping frames to keep the connection alive. + pub ping_interval: Option, + /// Optional initial per-stream receive window size in bytes. + /// If `None`, the default of 8MB is used. + pub stream_window_size: Option, + /// Optional initial connection-level receive window size in bytes. + /// If `None`, the default of 8MB is used. + pub connection_window_size: Option, +} + +impl H2HandshakeSettings { + /// Create a new `H2HandshakeSettings` with all defaults. + pub fn new() -> Self { + Self::default() + } +} + +/// Perform an HTTP/2 handshake on the given stream with the given settings. 
+pub async fn handshake(stream: Stream, settings: H2HandshakeSettings) -> Result { use h2::client::Builder; use pingora_runtime::current_handle; + let max_streams = settings.max_streams; + // Safe guard: new_http_session() assumes there should be at least one free stream if max_streams == 0 { return Error::e_explain(H2Error, "zero max_stream configured"); } + // Validate window sizes against RFC 9113 §6.9.1 limit + // https://datatracker.ietf.org/doc/html/rfc9113#section-6.9.1-7 + if settings + .stream_window_size + .is_some_and(|w| w == 0 || w > H2_MAX_WINDOW_SIZE) + { + return Error::e_explain( + H2Error, + format!( + "stream_window_size must be between 1 and {} (2^31-1)", + H2_MAX_WINDOW_SIZE + ), + ); + } + if settings + .connection_window_size + .is_some_and(|w| w == 0 || w > H2_MAX_WINDOW_SIZE) + { + return Error::e_explain( + H2Error, + format!( + "connection_window_size must be between 1 and {} (2^31-1)", + H2_MAX_WINDOW_SIZE + ), + ); + } + let id = stream.id(); let digest = Digest { // NOTE: this field is always false because the digest is shared across all streams @@ -498,16 +577,16 @@ pub async fn handshake( proxy_digest: stream.get_proxy_digest(), socket_digest: stream.get_socket_digest(), }; - // TODO: make these configurable + let stream_window = settings.stream_window_size.unwrap_or(H2_WINDOW_SIZE); + let conn_window = settings.connection_window_size.unwrap_or(H2_WINDOW_SIZE); let (send_req, connection) = Builder::new() .enable_push(false) .initial_max_send_streams(max_streams) // The limit for the server. Server push is not allowed, so this value doesn't matter .max_concurrent_streams(1) .max_frame_size(64 * 1024) // advise server to send larger frames - .initial_window_size(H2_WINDOW_SIZE) - // should this be max_streams * H2_WINDOW_SIZE? 
- .initial_connection_window_size(H2_WINDOW_SIZE) + .initial_window_size(stream_window) + .initial_connection_window_size(conn_window) .handshake(stream) .await .or_err(HandshakeError, "during H2 handshake")?; @@ -529,7 +608,7 @@ pub async fn handshake( connection, id, closed_tx, - h2_ping_interval, + settings.ping_interval, ping_timeout_clone, ) .await; @@ -549,6 +628,9 @@ pub async fn handshake( mod tests { use super::*; use crate::upstreams::peer::HttpPeer; + use bytes::Bytes; + use http::{Response, StatusCode}; + use pingora_http::RequestHeader; #[tokio::test] #[cfg(feature = "any_tls")] @@ -682,6 +764,46 @@ mod tests { assert_eq!(id, h2_5.conn.id()); } + /// `spawn_stream` must return `Ok(None)` and mark the connection as shutting + /// down when the underlying I/O channel is closed (BrokenPipe). This + /// exercises the BrokenPipe branch of `spawn_stream` directly without going + /// through the full proxy stack. + #[tokio::test] + async fn test_spawn_stream_broken_pipe_marks_shutdown() { + let (client_io, server_io) = tokio::io::duplex(65536); + let (send_req, connection) = h2::client::handshake(client_io).await.unwrap(); + let (closed_tx, closed_rx) = watch::channel(false); + let ping_timeout = Arc::new(AtomicBool::new(false)); + let conn = ConnectionRef::new(send_req, closed_rx, ping_timeout, 0, 10, Digest::default()); + + // Drive the H2 client connection task in the background. + // When the connection terminates it will signal via closed_tx. + let conn_handle = tokio::spawn(async move { + let _ = connection.await; + // Signal that the connection task has finished. + let _ = closed_tx.send(true); + }); + + // Complete the server-side H2 handshake, then drop the server connection. + // Dropping server_conn closes the write end of the duplex, so the client + // connection task will read EOF and terminate with BrokenPipe. 
+ let server_conn = h2::server::handshake(server_io).await.unwrap(); + drop(server_conn); + + // Wait until the client connection task has fully processed the EOF. + conn_handle.await.unwrap(); + + // spawn_stream must detect BrokenPipe, mark shutdown, and return Ok(None) + // so the caller can retry on a fresh connection rather than propagating the error. + let result = conn.spawn_stream().await; + assert!(result.is_ok(), "expected Ok(None), got Err"); + assert!(result.unwrap().is_none(), "expected None stream"); + assert!( + conn.is_shutting_down(), + "connection should be marked as shutting down" + ); + } + #[tokio::test] async fn test_mark_shutdown_prevents_new_streams() { let (client_io, _server_io) = tokio::io::duplex(65536); @@ -769,4 +891,142 @@ mod tests { .unwrap() .is_none()); } + + #[tokio::test] + async fn test_h2_handshake_settings_validation() { + use super::H2HandshakeSettings; + + // Test zero stream window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.stream_window_size = Some(0); + let (client, _server) = tokio::io::duplex(65536); + match handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("stream_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for stream_window_size = 0"), + } + + // Test zero connection window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.connection_window_size = Some(0); + let (client, _server) = tokio::io::duplex(65536); + match handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("connection_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for connection_window_size = 0"), + } + + // Test exceeding max stream window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + 
settings.stream_window_size = Some(super::H2_MAX_WINDOW_SIZE + 1); + let (client, _server) = tokio::io::duplex(65536); + match handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("stream_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for stream_window_size > max"), + } + + // Test exceeding max connection window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.connection_window_size = Some(super::H2_MAX_WINDOW_SIZE + 1); + let (client, _server) = tokio::io::duplex(65536); + match handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("connection_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for connection_window_size > max"), + } + } + + #[tokio::test] + async fn test_h2_handshake_custom_window_sizes() { + // Test that valid custom window sizes are accepted and handshake succeeds + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.stream_window_size = Some(1 << 20); // 1MiB + settings.connection_window_size = Some(1 << 24); // 16MiB + + let (client, server) = tokio::io::duplex(65536); + + // Spawn server side + tokio::spawn(async move { + let mut server_conn = h2::server::handshake(server).await.unwrap(); + if let Some(result) = server_conn.accept().await { + let (_request, mut respond) = result.unwrap(); + let resp = Response::builder().status(StatusCode::OK).body(()).unwrap(); + let mut stream = respond.send_response(resp, false).unwrap(); + stream.send_data(Bytes::from("ok"), true).unwrap(); + server_conn.graceful_shutdown(); + } + // Drive the server connection until the client closes + while let Some(_res) = server_conn.accept().await {} + }); + + // Client side - should succeed with custom window sizes + let conn = handshake(Box::new(client), settings).await.unwrap(); + + // 
Verify we can spawn a stream and complete a request/response cycle + let mut stream = conn.spawn_stream().await.unwrap().unwrap(); + let mut request = RequestHeader::build("GET", b"/", None).unwrap(); + request + .insert_header(http::header::HOST, "example.com") + .unwrap(); + stream + .write_request_header(Box::new(request), true) + .unwrap(); + + stream.read_response_header().await.unwrap(); + assert_eq!(stream.response_header().unwrap().status, 200); + } + + /// `spawn_stream()` must return `Ok(None)` when the server sends + /// GOAWAY(NO_ERROR) before any streams are opened. + #[tokio::test] + async fn test_spawn_stream_goaway_no_error_returns_none() { + let (client_io, server_io) = tokio::io::duplex(65536); + let (send_req, connection) = h2::client::handshake(client_io).await.unwrap(); + let (closed_tx, closed_rx) = watch::channel(false); + let ping_timeout = Arc::new(AtomicBool::new(false)); + let conn = ConnectionRef::new(send_req, closed_rx, ping_timeout, 0, 10, Digest::default()); + + let conn_handle = tokio::spawn(async move { + let _ = connection.await; + let _ = closed_tx.send(true); + }); + + let mut server_conn = h2::server::handshake(server_io).await.unwrap(); + server_conn.graceful_shutdown(); + let _ = server_conn.accept().await; + drop(server_conn); + + conn_handle.await.unwrap(); + + let result = conn.spawn_stream().await; + assert!( + result.is_ok(), + "expected Ok(None), got Err: {:?}", + result.as_ref().err() + ); + assert!(result.unwrap().is_none()); + assert!(conn.is_shutting_down()); + } } diff --git a/pingora-core/src/connectors/l4.rs b/pingora-core/src/connectors/l4.rs index bd7439d4..d275030f 100644 --- a/pingora-core/src/connectors/l4.rs +++ b/pingora-core/src/connectors/l4.rs @@ -313,15 +313,9 @@ mod tests { use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; - use tokio::io::AsyncWriteExt; use tokio::time::sleep; - /// Some of the tests below are flaky when making new connections to 
mock - /// servers. The servers are simple tokio listeners, so failures there are - /// not indicative of real errors. This function will retry the peer/server - /// in increasing intervals until it either succeeds in connecting or a long - /// timeout expires (max 10sec) - #[cfg(unix)] + #[cfg(target_os = "linux")] async fn wait_for_peer

(peer: &P) where P: Peer + Send + Sync, @@ -394,11 +388,17 @@ mod tests { #[tokio::test] async fn test_conn_timeout() { - // 192.0.2.1 is effectively a blackhole + // 192.0.2.1 is TEST-NET-1 (RFC 5737) — SYN packets are silently + // dropped on Linux, producing ConnectTimedout. On macOS the kernel + // may instead return ENETUNREACH (ConnectNoRoute). let mut peer = BasicPeer::new("192.0.2.1:79"); - peer.options.connection_timeout = Some(std::time::Duration::from_millis(1)); //1ms - let new_session = connect(&peer, None).await; - assert_eq!(new_session.unwrap_err().etype(), &ConnectTimedout) + peer.options.connection_timeout = Some(Duration::from_millis(1)); + let err = connect(&peer, None).await.unwrap_err(); + assert!( + err.etype() == &ConnectTimedout || err.etype() == &ConnectNoRoute, + "unexpected error type: {:?}", + err.etype() + ); } #[tokio::test] @@ -412,7 +412,7 @@ mod tests { let move_flag = Arc::clone(&flag); peer.options.upstream_tcp_sock_tweak_hook = Some(Arc::new(move |_| { - move_flag.fetch_xor(true, Ordering::SeqCst); + move_flag.fetch_not(Ordering::SeqCst); Ok(()) })); @@ -537,6 +537,7 @@ mod tests { // one-off mock server async fn mock_inet_connect_server() -> u16 { + use tokio::io::AsyncWriteExt; use tokio::net::TcpListener; let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/pingora-core/src/connectors/mod.rs b/pingora-core/src/connectors/mod.rs index 3e3c1c46..35067fa3 100644 --- a/pingora-core/src/connectors/mod.rs +++ b/pingora-core/src/connectors/mod.rs @@ -37,6 +37,7 @@ use pingora_error::{Error, ErrorType::*, OrErr, Result}; use pingora_pool::{ConnectionMeta, ConnectionPool}; use std::collections::HashMap; use std::net::SocketAddr; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use tls::TlsConnector; use tokio::sync::Mutex; @@ -146,6 +147,9 @@ pub struct TransportConnector { bind_to_v4: Vec, bind_to_v6: Vec, preferred_http_version: PreferredHttpVersion, + /// Wrapped in `Arc` so external 
consumers (e.g. proxy services) can clone a reference + /// for periodic metric reporting without needing access to the connector itself. + unexpected_data_conn_count: Arc, } const DEFAULT_POOL_SIZE: usize = 128; @@ -172,6 +176,7 @@ impl TransportConnector { bind_to_v4, bind_to_v6, preferred_http_version: PreferredHttpVersion::new(), + unexpected_data_conn_count: Arc::new(AtomicU64::new(0)), } } @@ -212,7 +217,9 @@ impl TransportConnector { // test_reusable_stream: we assume server would never actively send data // first on an idle stream. #[cfg(unix)] - if peer.matches_fd(stream.id()) && test_reusable_stream(&mut stream) { + if peer.matches_fd(stream.id()) + && test_reusable_stream(&mut stream, &self.unexpected_data_conn_count) + { Some(stream) } else { None @@ -227,7 +234,10 @@ impl TransportConnector { } } if peer.matches_sock(WrappedRawSocket(stream.id() as RawSocket)) - && test_reusable_stream(&mut stream) + && test_reusable_stream( + &mut stream, + &self.unexpected_data_conn_count, + ) { Some(stream) } else { @@ -261,7 +271,7 @@ impl TransportConnector { key: u64, // usually peer.reuse_hash() idle_timeout: Option, ) { - if !test_reusable_stream(&mut stream) { + if !test_reusable_stream(&mut stream, &self.unexpected_data_conn_count) { return; } let id = stream.id(); @@ -301,6 +311,21 @@ impl TransportConnector { pub fn prefer_h1(&self, peer: &impl Peer) { self.preferred_http_version.add(peer, 1); } + + /// Return the number of times a pooled connection was found to contain unexpected data + /// from the server. + pub fn unexpected_data_connection_count(&self) -> u64 { + self.unexpected_data_conn_count.load(Ordering::Relaxed) + } + + /// Return a shared reference to the unexpected data connection counter. + /// + /// This allows external consumers (e.g. proxy services) to clone the `Arc` and + /// periodically read the counter for metric reporting without needing ongoing + /// access to the connector. 
+ pub fn unexpected_data_connection_counter(&self) -> Arc { + self.unexpected_data_conn_count.clone() + } } // Perform the actual L4 and tls connection steps while respecting the peer's @@ -376,7 +401,7 @@ use futures::future::FutureExt; use tokio::io::AsyncReadExt; /// Test whether a stream is already closed or not reusable (server sent unexpected data) -fn test_reusable_stream(stream: &mut Stream) -> bool { +fn test_reusable_stream(stream: &mut Stream, unexpected_data_conn_count: &AtomicU64) -> bool { let mut buf = [0; 1]; // tokio::task::unconstrained because now_or_never may yield None when the future is ready let result = tokio::task::unconstrained(stream.read(&mut buf[..])).now_or_never(); @@ -387,6 +412,7 @@ fn test_reusable_stream(stream: &mut Stream) -> bool { debug!("Idle connection is closed"); } else { warn!("Unexpected data read in idle connection"); + unexpected_data_conn_count.fetch_add(1, Ordering::Relaxed); } } Err(e) => { @@ -482,15 +508,14 @@ pub(crate) mod test_utils { #[cfg(test)] #[cfg(feature = "any_tls")] mod tests { + use std::time::Duration; + use pingora_error::ErrorType; use tls::Connector; use super::*; use crate::upstreams::peer::BasicPeer; - // 192.0.2.1 is effectively a black hole - const BLACK_HOLE: &str = "192.0.2.1:79"; - #[tokio::test] async fn test_connect() { let connector = TransportConnector::new(None); @@ -547,15 +572,23 @@ mod tests { server_handle.await.unwrap(); } + // 192.0.2.1 is TEST-NET-1 (RFC 5737) — SYN packets are silently + // dropped on Linux, producing ConnectTimedout. On macOS the kernel + // may instead return ENETUNREACH (ConnectNoRoute). 
+ const BLACKHOLE: &str = "192.0.2.1:79"; + async fn do_test_conn_timeout(conf: Option) { let connector = TransportConnector::new(conf); - let mut peer = BasicPeer::new(BLACK_HOLE); - peer.options.connection_timeout = Some(std::time::Duration::from_millis(1)); - let stream = connector.new_stream(&peer).await; - match stream { - Ok(_) => panic!("should throw an error"), - Err(e) => assert_eq!(e.etype(), &ConnectTimedout), - } + let mut peer = BasicPeer::new(BLACKHOLE); + peer.options.connection_timeout = Some(Duration::from_millis(1)); + let Err(e) = connector.new_stream(&peer).await else { + panic!("should throw an error"); + }; + assert!( + e.etype() == &ConnectTimedout || e.etype() == &ConnectNoRoute, + "unexpected error type: {:?}", + e.etype() + ); } #[tokio::test] @@ -580,13 +613,21 @@ mod tests { let stream = connector.new_stream(&peer).await; let error = stream.unwrap_err(); - // XXX: some systems will allow the socket to bind and connect without error, only to timeout - assert!(error.etype() == &ConnectError || error.etype() == &ConnectTimedout) + // The exact error varies by platform: Linux may return ConnectError, + // some systems time out (ConnectTimedout), and macOS/others may + // return ConnectNoRoute (ENETUNREACH) for unreachable addresses. + assert!( + error.etype() == &ConnectError + || error.etype() == &ConnectTimedout + || error.etype() == &ConnectNoRoute, + "unexpected error type: {:?}", + error.etype() + ) } /// Helper function for testing error handling in the `do_connect` function. - /// This assumes that the connection will fail to on the peer and returns - /// the decomposed error type and message + /// This assumes that the connection will fail on the peer and returns + /// the decomposed error type and message. 
async fn get_do_connect_failure_with_peer(peer: &BasicPeer) -> (ErrorType, String) { let tls_connector = Connector::new(None); let stream = do_connect(peer, None, None, &tls_connector.ctx).await; @@ -604,27 +645,65 @@ mod tests { #[tokio::test] async fn test_do_connect_with_total_timeout() { - let mut peer = BasicPeer::new(BLACK_HOLE); - peer.options.total_connection_timeout = Some(std::time::Duration::from_millis(1)); + let mut peer = BasicPeer::new(BLACKHOLE); + peer.options.total_connection_timeout = Some(Duration::from_millis(1)); let (etype, context) = get_do_connect_failure_with_peer(&peer).await; - assert_eq!(etype, ConnectTimedout); - assert!(context.contains("total-connection timeout")); + assert!( + etype == ConnectTimedout || etype == ConnectNoRoute, + "unexpected error type: {etype:?}" + ); + if etype == ConnectTimedout { + assert!(context.contains("total-connection timeout")); + } } #[tokio::test] async fn test_tls_connect_timeout_supersedes_total() { - let mut peer = BasicPeer::new(BLACK_HOLE); - peer.options.total_connection_timeout = Some(std::time::Duration::from_millis(10)); - peer.options.connection_timeout = Some(std::time::Duration::from_millis(1)); + let mut peer = BasicPeer::new(BLACKHOLE); + peer.options.total_connection_timeout = Some(Duration::from_millis(10)); + peer.options.connection_timeout = Some(Duration::from_millis(1)); let (etype, context) = get_do_connect_failure_with_peer(&peer).await; - assert_eq!(etype, ConnectTimedout); - assert!(!context.contains("total-connection timeout")); + assert!( + etype == ConnectTimedout || etype == ConnectNoRoute, + "unexpected error type: {etype:?}" + ); + if etype == ConnectTimedout { + assert!(!context.contains("total-connection timeout")); + } } #[tokio::test] async fn test_do_connect_without_total_timeout() { - let peer = BasicPeer::new(BLACK_HOLE); + let peer = BasicPeer::new(BLACKHOLE); let (etype, context) = get_do_connect_failure_with_peer(&peer).await; assert!(etype != ConnectTimedout || 
!context.contains("total-connection timeout")); } + + #[tokio::test] + async fn test_unexpected_data_connection_count_increments() { + // Create a duplex stream where we control both ends + let (mut server, client) = tokio::io::duplex(64); + + let counter = AtomicU64::new(0); + let mut stream: Stream = Box::new(client); + + // With no data available, the stream should be considered reusable + assert!(test_reusable_stream(&mut stream, &counter)); + assert_eq!(counter.load(Ordering::Relaxed), 0); + + // Write unexpected data from the server side + use tokio::io::AsyncWriteExt; + server.write_all(b"unexpected").await.unwrap(); + + // Give the data a moment to be buffered + tokio::task::yield_now().await; + + // Now test_reusable_stream should detect the unexpected data + assert!(!test_reusable_stream(&mut stream, &counter)); + assert_eq!( + counter.load(Ordering::Relaxed), + 1, + "unexpected_data_connection_count should have incremented" + ); + } } diff --git a/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs b/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs index 9bb3a5a6..8585deef 100644 --- a/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs +++ b/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs @@ -193,7 +193,7 @@ where } } - if let Some(curve) = peer.get_peer_options().and_then(|o| o.curves) { + if let Some(curve) = peer.get_peer_options().and_then(|o| o.curves.as_deref()) { ssl_set_groups_list(&mut ssl_conf, curve).or_err(InternalError, "invalid curves")?; } diff --git a/pingora-core/src/connectors/tls/rustls/mod.rs b/pingora-core/src/connectors/tls/rustls/mod.rs index 58ea4085..23e3a307 100644 --- a/pingora-core/src/connectors/tls/rustls/mod.rs +++ b/pingora-core/src/connectors/tls/rustls/mod.rs @@ -61,6 +61,9 @@ impl TlsConnector { where Self: Sized, { + // rustls 0.23+ requires an explicit CryptoProvider. 
+ pingora_rustls::install_default_crypto_provider(); + // NOTE: Rustls only supports TLS 1.2 & 1.3 // TODO: currently using Rustls defaults diff --git a/pingora-core/src/lib.rs b/pingora-core/src/lib.rs index a4450632..b66282c9 100644 --- a/pingora-core/src/lib.rs +++ b/pingora-core/src/lib.rs @@ -105,6 +105,11 @@ pub mod utils; pub use listeners::set_client_hello_callback; pub use listeners::ClientHelloCallback; +// Re-export PROXY v2 extension-TLV callback for app-defined metadata +// (e.g. synapse's per-flow fingerprint store, see gen0sec/synapse#352). +pub use listeners::set_proxy_v2_tlv_callback; +pub use listeners::ProxyV2TlvCallback; + pub use pingora_error::{ErrorType::*, *}; // If both openssl and boringssl are enabled, prefer boringssl. diff --git a/pingora-core/src/listeners/l4.rs b/pingora-core/src/listeners/l4.rs index 1c0052f8..5532b635 100644 --- a/pingora-core/src/listeners/l4.rs +++ b/pingora-core/src/listeners/l4.rs @@ -44,6 +44,20 @@ use crate::protocols::GetSocketDigest; use crate::protocols::TcpKeepalive; #[cfg(unix)] use crate::server::ListenFds; +#[cfg(unix)] +use std::sync::LazyLock; + +/// Per-address async lock map for serializing the check-bind-insert sequence +/// in [`ListenerEndpointBuilder::listen`]. +/// +/// With `ListenFds` using a synchronous `parking_lot::Mutex`, the lock cannot +/// be held across `bind().await`. This global map ensures that only one task at +/// a time can be in the process of looking up, binding, and inserting a given +/// address — preventing two concurrent callers from both seeing "not found" and +/// racing to bind the same address. 
+#[cfg(unix)] +static BIND_LOCKS: LazyLock>>> = + LazyLock::new(flurry::HashMap::new); const TCP_LISTENER_MAX_TRY: usize = 30; const TCP_LISTENER_TRY_STEP: Duration = Duration::from_secs(1); @@ -326,19 +340,38 @@ impl ListenerEndpointBuilder { let listener = if let Some(fds_table) = fds { let addr_str = listen_addr.as_ref(); - // consider make this mutex std::sync::Mutex or OnceCell - let mut table = fds_table.lock().await; + // Acquire a per-address async lock so that only one task at a + // time can go through the check-bind-insert sequence for a given + // address. The flurry guard is dropped before the await so its + // !Send pointer does not cross an await point. + let addr_lock = { + let guard = BIND_LOCKS.pin(); + match guard.get(addr_str) { + Some(existing) => existing.clone(), + None => { + let new_lock = Arc::new(tokio::sync::Mutex::new(())); + match guard.try_insert(addr_str.to_string(), new_lock.clone()) { + Ok(inserted) => inserted.clone(), + Err(e) => e.current.clone(), + } + } + } + }; + let _guard = addr_lock.lock().await; + + let existing_fd = fds_table.lock().get(addr_str).copied(); - if let Some(fd) = table.get(addr_str) { - from_raw_fd(&listen_addr, *fd)? + if let Some(fd) = existing_fd { + from_raw_fd(&listen_addr, fd)? } else { - // not found let listener = bind(&listen_addr).await?; - table.add(addr_str.to_string(), listener.as_raw_fd()); + fds_table + .lock() + .add(addr_str.to_string(), listener.as_raw_fd()); listener } } else { - // not found, no fd table + // no fd table bind(&listen_addr).await? }; @@ -386,6 +419,15 @@ impl ListenerEndpoint { self.listen_addr.as_ref() } + /// Return the local address this endpoint is bound to. + /// + /// Useful when the listener was bound to port 0 (OS-assigned) to + /// discover the actual port. 
+ #[cfg(test)] + pub fn local_addr(&self) -> Option { + self.listener.local_addr() + } + fn apply_stream_settings(&self, stream: &mut Stream) -> Result<()> { // settings are applied based on whether the underlying stream supports it stream.set_nodelay()?; @@ -470,11 +512,9 @@ mod test { #[tokio::test] async fn test_listen_tcp() { - let addr = "127.0.0.1:7100"; - let mut builder = ListenerEndpoint::builder(); - builder.listen_addr(ServerAddress::Tcp(addr.into(), None)); + builder.listen_addr(ServerAddress::Tcp("127.0.0.1:0".into(), None)); #[cfg(unix)] let listener = builder.listen(None).await.unwrap(); @@ -482,6 +522,8 @@ mod test { #[cfg(windows)] let listener = builder.listen().await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { // just try to accept once listener.accept().await.unwrap(); @@ -500,7 +542,7 @@ mod test { let mut builder = ListenerEndpoint::builder(); - builder.listen_addr(ServerAddress::Tcp("[::]:7101".into(), sock_opt)); + builder.listen_addr(ServerAddress::Tcp("[::]:0".into(), sock_opt)); #[cfg(unix)] let listener = builder.listen(None).await.unwrap(); @@ -508,15 +550,17 @@ mod test { #[cfg(windows)] let listener = builder.listen().await.unwrap(); + let port = listener.local_addr().unwrap().port(); + tokio::spawn(async move { // just try to accept twice listener.accept().await.unwrap(); listener.accept().await.unwrap(); }); - tokio::net::TcpStream::connect("127.0.0.1:7101") + tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")) .await .expect_err("cannot connect to v4 addr"); - tokio::net::TcpStream::connect("[::1]:7101") + tokio::net::TcpStream::connect(format!("[::1]:{port}")) .await .expect("can connect to v6 addr"); } diff --git a/pingora-core/src/listeners/mod.rs b/pingora-core/src/listeners/mod.rs index abc65ea1..e209306e 100644 --- a/pingora-core/src/listeners/mod.rs +++ b/pingora-core/src/listeners/mod.rs @@ -129,6 +129,66 @@ fn call_client_hello_callback( } } +/// Callback function type for 
parsed PROXY protocol v2 extension TLVs. +/// +/// Invoked from `UninitializedStream::handshake` right after +/// `maybe_consume_proxy_header` parses a v2 header and recovers the +/// real client `SocketAddr`. Lets consumer applications (synapse, +/// custom logging) pull metadata out of application-defined TLVs +/// (HAProxy spec § 2.2 reserves type IDs `0xE0..=0xEF`) without +/// needing to parse the wire bytes themselves. +/// +/// Signature mirrors `ClientHelloCallback`: a pointer fn so the +/// registration can stay `Sync` without an `Arc` shuffle. +pub type ProxyV2TlvCallback = Option; + +/// Global callback for parsed PROXY v2 extension TLVs. Registered by +/// `set_proxy_v2_tlv_callback`, invoked by `call_proxy_v2_tlv_callback`. +static PROXY_V2_TLV_CALLBACK: std::sync::OnceLock> = + std::sync::OnceLock::new(); + +/// Register a callback that receives the parsed PROXY v2 extension +/// TLVs from every accepted connection (in addition to the recovered +/// source address that `maybe_consume_proxy_header` already plumbs +/// into the SocketDigest). +/// +/// # Example +/// ``` +/// use pingora_core::listeners::set_proxy_v2_tlv_callback; +/// use pingora_core::protocols::proxy_protocol::ExtensionTlv; +/// use pingora_core::protocols::l4::socket::SocketAddr; +/// +/// set_proxy_v2_tlv_callback(Some(|tlvs: &[ExtensionTlv], real_addr: SocketAddr| { +/// for tlv in tlvs { +/// if let ExtensionTlv::Custom { type_id: 0xE0, value } = tlv { +/// // Application-defined TLV — decode and act on it. +/// let _ = (&real_addr, value); +/// } +/// } +/// })); +/// ``` +pub fn set_proxy_v2_tlv_callback(callback: ProxyV2TlvCallback) { + PROXY_V2_TLV_CALLBACK.get_or_init(|| std::sync::Mutex::new(callback)); + if let Ok(mut cb) = PROXY_V2_TLV_CALLBACK.get().unwrap().lock() { + *cb = callback; + } +} + +/// Invoke the PROXY v2 TLV callback if registered. No-op when no +/// callback is set (the common case) or when the TLV list is empty. 
+fn call_proxy_v2_tlv_callback(tlvs: &[proxy_protocol::ExtensionTlv], real_addr: SocketAddr) { + if tlvs.is_empty() { + return; + } + if let Some(cb_guard) = PROXY_V2_TLV_CALLBACK.get() { + if let Ok(cb) = cb_guard.lock() { + if let Some(callback) = *cb { + callback(tlvs, real_addr); + } + } + } +} + #[cfg(unix)] use crate::server::ListenFds; @@ -139,6 +199,9 @@ use std::{any::Any, fs::Permissions, sync::Arc}; use l4::{ListenerEndpoint, Stream as L4Stream}; use tls::{Acceptor, TlsSettings}; +pub use crate::protocols::l4::stream::{ + L4BufferSettings, DEFAULT_L4_READ_BUFFER_SIZE, DEFAULT_L4_WRITE_BUFFER_SIZE, +}; pub use crate::protocols::tls::ALPN; use crate::protocols::GetSocketDigest; pub use l4::{ServerAddress, TcpSocketOptions}; @@ -174,6 +237,7 @@ pub type TlsAcceptCallbacks = Box; struct TransportStackBuilder { l4: ServerAddress, tls: Option, + l4_buffer: L4BufferSettings, #[cfg(feature = "connection_filter")] connection_filter: Option>, } @@ -201,14 +265,94 @@ impl TransportStackBuilder { Ok(TransportStack { l4, tls: self.tls.take().map(|tls| Arc::new(tls.build())), + l4_buffer: self.l4_buffer, }) } } +/// Configuration for one listening endpoint. +/// +/// This configures the endpoint address and endpoint-specific transport +/// settings such as [`TcpSocketOptions`], [`TlsSettings`], and L4 +/// [`BufStream`](tokio::io::BufStream) buffer sizes. +pub struct ListenerConfig { + l4: ServerAddress, + tls: Option, + l4_buffer: L4BufferSettings, +} + +impl ListenerConfig { + /// Create a TCP listening endpoint config. + pub fn tcp(addr: impl Into) -> Self { + Self { + l4: ServerAddress::Tcp(addr.into(), None), + tls: None, + l4_buffer: L4BufferSettings::default(), + } + } + + /// Create a Unix domain socket listening endpoint config. + #[cfg(unix)] + pub fn uds(addr: impl Into) -> Self { + Self { + l4: ServerAddress::Uds(addr.into(), None), + tls: None, + l4_buffer: L4BufferSettings::default(), + } + } + + /// Set TCP socket options for this endpoint. 
+ /// + /// # Panics + /// + /// Panics if this endpoint is not TCP. + #[track_caller] + pub fn tcp_socket_options(mut self, options: TcpSocketOptions) -> Self { + match &mut self.l4 { + ServerAddress::Tcp(_, opt) => *opt = Some(options), + #[cfg(unix)] + ServerAddress::Uds(_, _) => { + panic!("TCP socket options can only be set on TCP endpoints") + } + } + self + } + + /// Set Unix domain socket permissions for this endpoint. + /// + /// # Panics + /// + /// Panics if this endpoint is not a Unix domain socket. + #[cfg(unix)] + #[track_caller] + pub fn permissions(mut self, permissions: Permissions) -> Self { + match &mut self.l4 { + ServerAddress::Uds(_, perm) => *perm = Some(permissions), + ServerAddress::Tcp(_, _) => { + panic!("Unix domain socket permissions can only be set on UDS endpoints") + } + } + self + } + + /// Set TLS settings for this endpoint. + pub fn tls(mut self, settings: TlsSettings) -> Self { + self.tls = Some(settings); + self + } + + /// Set L4 `BufStream` buffer sizes for this endpoint. + pub fn l4_buffer(mut self, settings: L4BufferSettings) -> Self { + self.l4_buffer = settings; + self + } +} + #[derive(Clone)] pub(crate) struct TransportStack { l4: ListenerEndpoint, tls: Option>, + l4_buffer: L4BufferSettings, } impl TransportStack { @@ -221,6 +365,7 @@ impl TransportStack { Ok(UninitializedStream { l4: stream, tls: self.tls.clone(), + l4_buffer: self.l4_buffer, }) } @@ -232,11 +377,16 @@ impl TransportStack { pub(crate) struct UninitializedStream { l4: L4Stream, tls: Option>, + l4_buffer: L4BufferSettings, } impl UninitializedStream { pub async fn handshake(mut self) -> Result { - self.l4.set_buffer(); + // upstream/main gave `set_buffer` an explicit buffer-settings + // argument (`l4_buffer` field on UninitializedStream); keep the + // fork's ClientHello-then-PROXY-header ordering with the new + // signature. 
+ self.l4.set_buffer(self.l4_buffer); // Extract ClientHello BEFORE consuming PROXY headers // This is necessary because PROXY header consumption rewinds ClientHello to a buffer @@ -367,7 +517,8 @@ impl UninitializedStream { match proxy_protocol::consume_proxy_header(&mut self.l4).await { Ok(Some(header)) => { - if let Some(real_addr) = proxy_protocol::source_addr_from_header(&header) { + let real_addr_opt = proxy_protocol::source_addr_from_header(&header); + if let Some(real_addr) = real_addr_opt { if let Some(digest) = self.l4.get_socket_digest() { let client_addr = SocketAddr::Inet(real_addr); digest.set_client_addr(client_addr.clone()); @@ -385,6 +536,19 @@ impl UninitializedStream { } else { debug!("PROXY protocol header contained no client address (LOCAL command)"); } + // Surface any v2 extension TLVs to the registered + // application callback so consumers can ride extra + // metadata (e.g. JA4 fingerprints) on the same hop + // without an out-of-band channel. We need the real + // source address as the key, so skip when source + // recovery failed. + if let Some(real_addr) = real_addr_opt { + if let proxy_protocol::ProxyHeader::Version2 { extensions, .. } = &header { + if !extensions.is_empty() { + call_proxy_v2_tlv_callback(extensions, SocketAddr::Inet(real_addr)); + } + } + } } Ok(None) => { debug!("PROXY protocol is enabled but downstream connection from {} sent no header (connection will continue)", peer_str); @@ -445,18 +609,22 @@ impl Listeners { /// Add a TCP endpoint to `self`. pub fn add_tcp(&mut self, addr: &str) { - self.add_address(ServerAddress::Tcp(addr.into(), None)); + self.add_listener(ListenerConfig::tcp(addr)); } /// Add a TCP endpoint to `self`, with the given [`TcpSocketOptions`]. 
pub fn add_tcp_with_settings(&mut self, addr: &str, sock_opt: TcpSocketOptions) { - self.add_address(ServerAddress::Tcp(addr.into(), Some(sock_opt))); + self.add_listener(ListenerConfig::tcp(addr).tcp_socket_options(sock_opt)); } /// Add a Unix domain socket endpoint to `self`. #[cfg(unix)] pub fn add_uds(&mut self, addr: &str, perm: Option) { - self.add_address(ServerAddress::Uds(addr.into(), perm)); + let endpoint = perm.map_or_else( + || ListenerConfig::uds(addr), + |perm| ListenerConfig::uds(addr).permissions(perm), + ); + self.add_listener(endpoint); } /// Add a TLS endpoint to `self` with the [Mozilla Intermediate](https://wiki.mozilla.org/Security/Server_Side_TLS#Intermediate_compatibility_.28recommended.29) @@ -474,7 +642,11 @@ impl Listeners { sock_opt: Option, settings: TlsSettings, ) { - self.add_endpoint(ServerAddress::Tcp(addr.into(), sock_opt), Some(settings)); + let mut endpoint = ListenerConfig::tcp(addr).tls(settings); + if let Some(sock_opt) = sock_opt { + endpoint = endpoint.tcp_socket_options(sock_opt); + } + self.add_listener(endpoint); } /// Add the given [`ServerAddress`] to `self`. @@ -496,11 +668,24 @@ impl Listeners { } } - /// Add the given [`ServerAddress`] to `self` with the given [`TlsSettings`] if provided + /// Add the given listener endpoint to `self`. + pub fn add_listener(&mut self, endpoint: ListenerConfig) { + let ListenerConfig { l4, tls, l4_buffer } = endpoint; + self.stacks.push(TransportStackBuilder { + l4, + tls, + l4_buffer, + #[cfg(feature = "connection_filter")] + connection_filter: self.connection_filter.clone(), + }); + } + + /// Add the given [`ServerAddress`] to `self` with the given [`TlsSettings`] if provided. 
pub fn add_endpoint(&mut self, l4: ServerAddress, tls: Option) { self.stacks.push(TransportStackBuilder { l4, tls, + l4_buffer: L4BufferSettings::default(), #[cfg(feature = "connection_filter")] connection_filter: self.connection_filter.clone(), }) @@ -539,14 +724,11 @@ mod test { #[cfg(feature = "any_tls")] use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; - use tokio::time::{sleep, Duration}; #[tokio::test] async fn test_listen_tcp() { - let addr1 = "127.0.0.1:7101"; - let addr2 = "127.0.0.1:7102"; - let mut listeners = Listeners::tcp(addr1); - listeners.add_tcp(addr2); + let mut listeners = Listeners::tcp("127.0.0.1:0"); + listeners.add_tcp("127.0.0.1:0"); let listeners = listeners .build( @@ -557,6 +739,10 @@ mod test { .unwrap(); assert_eq!(listeners.len(), 2); + let addrs: Vec<_> = listeners + .iter() + .map(|s| s.l4.local_addr().unwrap()) + .collect(); for listener in listeners { tokio::spawn(async move { // just try to accept once @@ -565,11 +751,77 @@ mod test { }); } - // make sure the above starts before the lines below - sleep(Duration::from_millis(10)).await; + // The listeners are already bound (port resolved during build()), + // so the kernel accepts connections into the backlog immediately. + // No readiness wait needed — connect will succeed as soon as the + // OS has completed the TCP handshake. 
+ TcpStream::connect(addrs[0]).await.unwrap(); + TcpStream::connect(addrs[1]).await.unwrap(); + } - TcpStream::connect(addr1).await.unwrap(); - TcpStream::connect(addr2).await.unwrap(); + #[test] + fn test_add_listener_config_tcp_l4_buffer() { + let mut listeners = Listeners::new(); + let tcp_options = TcpSocketOptions { + dscp: Some(10), + ..Default::default() + }; + let l4_buffer = L4BufferSettings { + read: Some(0), + write: None, + }; + + listeners.add_listener( + ListenerConfig::tcp("127.0.0.1:7107") + .tcp_socket_options(tcp_options) + .l4_buffer(l4_buffer), + ); + + assert_eq!(listeners.stacks.len(), 1); + assert_eq!(listeners.stacks[0].l4_buffer, l4_buffer); + assert_eq!(listeners.stacks[0].l4_buffer.read_capacity(), 0); + assert_eq!( + listeners.stacks[0].l4_buffer.write_capacity(), + DEFAULT_L4_WRITE_BUFFER_SIZE + ); + + match &listeners.stacks[0].l4 { + ServerAddress::Tcp(addr, Some(options)) => { + assert_eq!(addr, "127.0.0.1:7107"); + assert_eq!(options.dscp, Some(10)); + } + other => panic!("unexpected listener address: {other:?}"), + } + } + + #[cfg(unix)] + #[test] + fn test_add_listener_config_uds_l4_buffer() { + let mut listeners = Listeners::new(); + let l4_buffer = L4BufferSettings::unbuffered(); + + listeners.add_listener(ListenerConfig::uds("/tmp/test_builder_uds").l4_buffer(l4_buffer)); + + assert_eq!(listeners.stacks.len(), 1); + assert_eq!(listeners.stacks[0].l4_buffer, l4_buffer); + assert_eq!(listeners.stacks[0].l4_buffer.read_capacity(), 0); + assert_eq!(listeners.stacks[0].l4_buffer.write_capacity(), 0); + + match &listeners.stacks[0].l4 { + ServerAddress::Uds(addr, None) => assert_eq!(addr, "/tmp/test_builder_uds"), + other => panic!("unexpected listener address: {other:?}"), + } + } + + #[test] + fn test_l4_buffer_settings_defaults_per_direction() { + let l4_buffer = L4BufferSettings { + read: None, + write: Some(0), + }; + + assert_eq!(l4_buffer.read_capacity(), DEFAULT_L4_READ_BUFFER_SIZE); + assert_eq!(l4_buffer.write_capacity(), 
0); } #[tokio::test] @@ -602,9 +854,8 @@ mod test { .await .unwrap(); }); - // make sure the above starts before the lines below - sleep(Duration::from_millis(10)).await; - + // The listener is already bound, so the kernel accepts connections + // into the backlog immediately. No readiness wait needed. let client = reqwest::Client::builder() .danger_accept_invalid_certs(true) .build() diff --git a/pingora-core/src/listeners/tls/rustls/mod.rs b/pingora-core/src/listeners/tls/rustls/mod.rs index 0ca94d51..e7376fc0 100644 --- a/pingora-core/src/listeners/tls/rustls/mod.rs +++ b/pingora-core/src/listeners/tls/rustls/mod.rs @@ -48,6 +48,9 @@ impl TlsSettings { /// /// Todo: Return a result instead of panicking XD pub fn build(self) -> Acceptor { + // rustls 0.23+ requires an explicit CryptoProvider. + pingora_rustls::install_default_crypto_provider(); + let Ok(Some((certs, key))) = load_certs_and_key_files(&self.cert_path, &self.key_path) else { panic!( diff --git a/pingora-core/src/protocols/http/compression/mod.rs b/pingora-core/src/protocols/http/compression/mod.rs index 9e8e96fb..301de0fb 100644 --- a/pingora-core/src/protocols/http/compression/mod.rs +++ b/pingora-core/src/protocols/http/compression/mod.rs @@ -303,9 +303,18 @@ impl ResponseCompressionCtx { Action::Compress(algorithm) => { let idx = algorithm.index(); let compressor = match algorithm { - Algorithm::Dcz => dictionary.as_ref().and_then(|d| { - algorithm.maybe_compressor_with_dictionary(levels[idx], d) - }), + Algorithm::Dcz => { + // RFC 9842: dictionary-compressed responses vary on + // Available-Dictionary so caches don't serve this variant + // to clients with a different or missing dictionary. 
+ let enc = dictionary.as_ref().and_then(|d| { + algorithm.maybe_compressor_with_dictionary(levels[idx], d) + }); + if enc.is_some() { + add_vary_header(resp, &AVAILABLE_DICTIONARY); + } + enc + } _ => algorithm.compressor(levels[idx]), }; (compressor, preserve_etag[idx]) @@ -780,6 +789,13 @@ fn compressible(resp: &ResponseHeader) -> bool { } } +/// Header name for the Available-Dictionary request header ([RFC 9842]). +/// TODO: Replace with http::header when available. +/// +/// [RFC 9842]: https://datatracker.ietf.org/doc/html/rfc9842 +static AVAILABLE_DICTIONARY: http::HeaderName = + http::HeaderName::from_static("available-dictionary"); + // add Vary header with the specified value or extend an existing Vary header value fn add_vary_header(resp: &mut ResponseHeader, value: &http::header::HeaderName) { use http::header::{HeaderValue, VARY}; @@ -1055,6 +1071,11 @@ mod tests_dictionary_compression { resp.headers.get("content-encoding").unwrap().as_bytes(), b"dcz" ); + // RFC 9842: DCZ responses must vary on Available-Dictionary. + assert!(resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"available-dictionary")))); let input = Bytes::from_static(b"The quick brown fox jumps over the lazy dog again."); let compressed = ctx.response_body_filter(Some(&input), true).unwrap(); @@ -1080,6 +1101,11 @@ mod tests_dictionary_compression { // no dictionary set, no compression applied assert!(resp.headers.get("content-encoding").is_none()); + // No compression → no Vary: available-dictionary. + assert!(!resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"available-dictionary")))); } #[test] @@ -1099,6 +1125,11 @@ mod tests_dictionary_compression { // dcz first but no dictionary, no automatic fallback assert!(resp.headers.get("content-encoding").is_none()); + // No compression → no Vary: available-dictionary. 
+ assert!(!resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"available-dictionary")))); } #[test] @@ -1152,6 +1183,11 @@ mod tests_dictionary_compression { resp.headers.get("transfer-encoding").unwrap().as_bytes(), b"chunked" ); + // RFC 9842: DCZ responses must vary on Available-Dictionary. + assert!(resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"available-dictionary")))); let chunk1 = Bytes::from_static(b"First chunk. "); let output1 = ctx.response_body_filter(Some(&chunk1), false); @@ -1166,4 +1202,33 @@ mod tests_dictionary_compression { assert_eq!(total_in, chunk1.len() + chunk2.len()); assert!(total_out > 0); } + + #[test] + fn regular_compression_no_available_dictionary_vary() { + // Gzip compression should produce Vary: Accept-Encoding but NOT + // Vary: available-dictionary. + let mut ctx = ResponseCompressionCtx::new(3, false, false); + + let mut req = RequestHeader::build("GET", b"/page.html", None).unwrap(); + req.insert_header("accept-encoding", "gzip").unwrap(); + ctx.request_filter(&req); + + let mut resp = ResponseHeader::build(200, None).unwrap(); + resp.insert_header("content-type", "text/html").unwrap(); + resp.insert_header("content-length", "1000").unwrap(); + ctx.response_header_filter(&mut resp, false); + + assert_eq!( + resp.headers.get("content-encoding").unwrap().as_bytes(), + b"gzip" + ); + assert!(resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"accept-encoding")))); + assert!(!resp.headers.get_all("vary").iter().any(|v| v + .as_bytes() + .split(|b| *b == b',') + .any(|t| t.trim_ascii().eq_ignore_ascii_case(b"available-dictionary")))); + } } diff --git a/pingora-core/src/protocols/http/server.rs b/pingora-core/src/protocols/http/server.rs index 051cc1f4..bdf2bdc5 100644 --- 
a/pingora-core/src/protocols/http/server.rs +++ b/pingora-core/src/protocols/http/server.rs @@ -457,6 +457,22 @@ impl Session { } } + /// Controls behaviour when the client closes the connection after the request body. + /// + /// When **enabled** (default), a client close is returned as a `ConnectionClosed` + /// error so the proxy aborts immediately. When **disabled**, `read_body_or_idle` + /// stays pending so the proxy can finish delivering the upstream response. + /// + /// Only meaningful for H1 (TCP). Noop for H2/subrequest/custom. + pub fn set_abort_on_close(&mut self, abort: bool) { + match self { + Self::H1(s) => s.set_abort_on_close(abort), + Self::H2(_) => {} + Self::Subrequest(_) => {} + Self::Custom(_) => {} + } + } + /// Return a digest of the request including the method, path and Host header // TODO: make this use a `Formatter` pub fn request_summary(&self) -> String { @@ -795,4 +811,72 @@ impl Session { Self::Custom(_) => None, } } + + /// Check if this session supports the cancel-safe proxy task API. + /// + /// Currently supported by HTTP/1.x and Subrequest server sessions; + /// toggled per-session via [`set_proxy_tasks_enabled`](Self::set_proxy_tasks_enabled). + pub fn supports_proxy_task_api(&self) -> bool { + match self { + Self::H1(s) => s.proxy_tasks_enabled(), + Self::Subrequest(s) => s.proxy_tasks_enabled(), + Self::H2(_) => false, + Self::Custom(_) => false, + } + } + + /// Enable or disable the cancel-safe proxy task API for this session. + pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { + match self { + Self::H1(s) => s.set_proxy_tasks_enabled(enabled), + Self::Subrequest(s) => s.set_proxy_tasks_enabled(enabled), + Self::H2(_) => {} + Self::Custom(_) => {} + } + } + + /// Queue a downstream proxy task for cancel-safe writing. + /// + /// # Panics + /// Panics if called on a session that doesn't support the proxy task API. 
+ /// Check [`supports_proxy_task_api`](Self::supports_proxy_task_api) first, + /// or use `write_response_header()` / `write_response_body()` for other + /// session types. + pub fn send_downstream_proxy_task(&mut self, task: HttpTask) { + match self { + Self::H1(s) => s.send_proxy_task(task), + Self::H2(_) => panic!("H2 proxy task API not yet implemented"), + Self::Subrequest(s) => s.send_proxy_task(task), + Self::Custom(_) => panic!("Custom proxy task API not yet implemented"), + } + } + + /// Check if there are pending downstream proxy tasks queued for writing. + /// + /// Returns false for sessions that don't support the proxy task API. + pub fn has_pending_downstream_proxy_tasks(&self) -> bool { + match self { + Self::H1(s) => s.has_pending_proxy_tasks(), + Self::H2(_) => false, // TODO: implement for H2 + Self::Subrequest(s) => s.has_pending_proxy_tasks(), + Self::Custom(_) => false, // TODO: implement for custom + } + } + + /// Write all queued downstream proxy tasks in a cancel-safe manner. + /// Returns `Ok(true)` if this was the end of the response stream. + /// + /// # Panics + /// Panics if called on a session that doesn't support the proxy task API. + /// Check [`supports_proxy_task_api`](Self::supports_proxy_task_api) first, + /// or use `write_response_header()` / `write_response_body()` for other + /// session types. 
+ pub async fn write_downstream_proxy_tasks(&mut self) -> Result { + match self { + Self::H1(s) => s.write_proxy_tasks().await, + Self::H2(_) => panic!("H2 proxy task API not yet implemented"), + Self::Subrequest(s) => s.write_proxy_tasks().await, + Self::Custom(_) => panic!("Custom proxy task API not yet implemented"), + } + } } diff --git a/pingora-core/src/protocols/http/subrequest/server.rs b/pingora-core/src/protocols/http/subrequest/server.rs index c91dbf91..938c8c97 100644 --- a/pingora-core/src/protocols/http/subrequest/server.rs +++ b/pingora-core/src/protocols/http/subrequest/server.rs @@ -36,13 +36,14 @@ use bytes::Bytes; use http::HeaderValue; use http::{header, header::AsHeaderName, HeaderMap, Method}; use log::{debug, trace, warn}; -use pingora_error::{Error, ErrorType::*, OkOrErr, Result}; +use pingora_error::{Error, ErrorType::*, OkOrErr, OrErr, Result}; use pingora_http::{RequestHeader, ResponseHeader}; use pingora_timeout::timeout; +use std::collections::VecDeque; use std::time::Duration; use tokio::sync::{mpsc, oneshot}; -use super::body::{BodyReader, BodyWriter}; +use super::body::{BodyMode, BodyReader, BodyWriter, PREMATURE_BODY_END}; use crate::protocols::http::{ body_buffer::FixedBuffer, server::Session as GenericHttpSession, @@ -53,6 +54,47 @@ use crate::protocols::http::{ }; use crate::protocols::{Digest, SocketAddr}; +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +enum StreamEndState { + #[default] + Open, + FinishRequired, + Finished, +} + +/// State for the cancel-safe proxy task write API. +#[derive(Default)] +struct ProxyTaskState { + /// Tasks remain queued until `Permit::send` commits them to the channel. + tasks: VecDeque, + /// Final `Done` is waiting for channel capacity. + finish_in_progress: bool, + /// End-of-stream observed from already-consumed tasks. 
+ stream_end: StreamEndState, +} + +impl ProxyTaskState { + fn require_finish(&mut self) { + self.stream_end = StreamEndState::FinishRequired; + } + + fn mark_finished(&mut self) { + self.stream_end = StreamEndState::Finished; + } + + fn clear_stream_end(&mut self) { + self.stream_end = StreamEndState::Open; + } + + fn finish_required(&self) -> bool { + self.stream_end == StreamEndState::FinishRequired + } + + fn finished(&self) -> bool { + self.stream_end == StreamEndState::Finished + } +} + /// The HTTP server session pub struct HttpSession { // these are only options because we allow dropping them separately on shutdown @@ -76,6 +118,11 @@ pub struct HttpSession { // TODO: likely doesn't need to be a separate bool when/if moving away from dummy SessionV1 clear_request_body_headers: bool, digest: Option>, + /// Whether the cancel-safe proxy task API is enabled for this session. + /// Defaults to `false`. Toggle via [`set_proxy_tasks_enabled`](Self::set_proxy_tasks_enabled). + proxy_tasks_enabled: bool, + /// Cancel-safe proxy task state. + proxy_task_state: ProxyTaskState, } /// A handle to the subrequest session itself to interact or read from it. 
@@ -133,6 +180,8 @@ impl HttpSession { upgraded: false, clear_request_body_headers: false, digest: digest.map(Box::new), + proxy_tasks_enabled: false, + proxy_task_state: ProxyTaskState::default(), }, SubrequestHandle { tx: downstream_tx, @@ -310,12 +359,24 @@ impl HttpSession { // XXX: don't add additional downstream headers, unlike h1, subreq is mostly treated as a pipe - // Allow informational header (excluding 101) to pass through without affecting the state - // of the request + let mut upgrade_ok: Option = None; if header.status == 101 || !header.status.is_informational() { - // reset request body to done for incomplete upgrade handshakes - if let Some(upgrade_ok) = self.is_upgrade(&header) { - if upgrade_ok { + upgrade_ok = self.is_upgrade(&header); + } + + // TODO propagate h2 end + debug!("send response header (subrequest)"); + match self + .tx + .as_mut() + .expect("tx valid before shutdown") + .send(HttpTask::Header(header.clone(), false)) + .await + { + Ok(()) => { + self.init_body_writer(&header); + self.response_written = Some(*header); + if let Some(true) = upgrade_ok { debug!("ok upgrade handshake"); // For ws we use HTTP1_0 do_read_body_until_closed // @@ -342,25 +403,10 @@ impl HttpSession { // TODO: this has no effect resetting the body counter of TE chunked self.body_reader.convert_to_close_delimited(); } - } else { + } else if upgrade_ok == Some(false) { debug!("bad upgrade handshake!"); // continue to read body as-is, this is now just a regular request } - } - self.init_body_writer(&header); - } - - // TODO propagate h2 end - debug!("send response header (subrequest)"); - match self - .tx - .as_mut() - .expect("tx valid before shutdown") - .send(HttpTask::Header(header.clone(), false)) - .await - { - Ok(()) => { - self.response_written = Some(*header); Ok(()) } Err(e) => Error::e_because(WriteError, "writing response header", e), @@ -390,40 +436,12 @@ impl HttpSession { } fn init_body_writer(&mut self, header: &ResponseHeader) { - use 
http::StatusCode; - /* the following responses don't have body 204, 304, and HEAD */ - if matches!( - header.status, - StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED - ) || self.get_method() == Some(&Method::HEAD) - { - self.body_writer.init_content_length(0); - return; - } - - if header.status.is_informational() && header.status != StatusCode::SWITCHING_PROTOCOLS { - // 1xx response, not enough to init body - return; - } - - if self.is_upgrade(header) == Some(true) { - self.body_writer.init_close_delimited(); - } else if is_chunked_encoding_from_headers(&header.headers) { - // transfer-encoding takes priority over content-length - self.body_writer.init_close_delimited(); - } else { - let content_length = - header_value_content_length(header.headers.get(http::header::CONTENT_LENGTH)); - match content_length { - Some(length) => { - self.body_writer.init_content_length(length); - } - None => { - /* TODO: 1. connection: keepalive cannot be used, - 2. mark connection must be closed */ - self.body_writer.init_close_delimited(); - } - } + if let Some(mode) = body_mode_for_header( + header, + self.get_method(), + self.is_upgrade(header) == Some(true), + ) { + apply_body_mode(&mut self.body_writer, mode); } } @@ -800,6 +818,428 @@ impl HttpSession { ) .await } + + // Cancel-safe proxy task API. Unlike v1's partial `AsyncWrite` state + // machines, subrequest uses mpsc `reserve()` + synchronous `Permit::send`. + + /// Whether the cancel-safe proxy task API is enabled for this session. + pub fn proxy_tasks_enabled(&self) -> bool { + self.proxy_tasks_enabled + } + + /// Enable or disable the cancel-safe proxy task API for this session. + pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { + self.proxy_tasks_enabled = enabled; + } + + /// Queue a proxy task for cancel-safe writing. + pub fn send_proxy_task(&mut self, task: HttpTask) { + self.proxy_task_state.tasks.push_back(task); + } + + /// Whether there are pending proxy tasks queued for writing. 
+ pub fn has_pending_proxy_tasks(&self) -> bool { + self.proxy_task_state.finish_in_progress + || self.proxy_task_state.finish_required() + || !self.proxy_task_state.tasks.is_empty() + } + + /// Write queued proxy tasks. Cancelling while waiting for channel capacity + /// leaves the current task in `proxy_task_state.tasks`. + pub async fn write_proxy_tasks(&mut self) -> Result { + loop { + if self.proxy_task_state.finished() { + self.proxy_task_state.tasks.clear(); + return Ok(true); + } + + if self.proxy_task_state.finish_in_progress { + self.finish_proxy_task().await?; + self.proxy_task_state.mark_finished(); + return Ok(true); + } + + let Some(front) = self.proxy_task_state.tasks.front() else { + break; + }; + + // Tasks with no underlying channel send: handle synchronously + // without reserving a permit. + match front { + HttpTask::Done => { + self.proxy_task_state.tasks.pop_front(); + self.proxy_task_state.require_finish(); + continue; + } + HttpTask::Failed(_) => { + let HttpTask::Failed(e) = self + .proxy_task_state + .tasks + .pop_front() + .expect("queue had a Failed task at the front") + else { + unreachable!() + }; + self.proxy_task_state.clear_stream_end(); + return Err(e); + } + _ => {} + } + + if let HttpTask::Header(_, header_end) = front { + let already_sent = match self.response_written.as_ref() { + Some(resp) => !resp.status.is_informational() || self.upgraded, + None => false, + }; + if already_sent { + warn!("Respond header is already sent, cannot send again (subrequest, proxy task)"); + if *header_end { + self.proxy_task_state.require_finish(); + } + self.proxy_task_state.tasks.pop_front(); + continue; + } + } + + if let Some((upgraded_task, body_end, no_data)) = match front { + HttpTask::Body(data, end) => { + Some((false, *end, data.as_ref().is_none_or(|d| d.is_empty()))) + } + HttpTask::UpgradedBody(data, end) => { + Some((true, *end, data.as_ref().is_none_or(|d| d.is_empty()))) + } + _ => None, + } { + if upgraded_task != self.upgraded { + 
if upgraded_task { + panic!("Unexpected UpgradedBody task received on un-upgraded downstream session (subrequest, proxy task)"); + } else { + panic!("Unexpected Body task received on upgraded downstream session (subrequest, proxy task)"); + } + } + + if body_end { + self.proxy_task_state.require_finish(); + } + + if no_data { + self.proxy_task_state.tasks.pop_front(); + continue; + } + + match self.body_writer.body_mode { + BodyMode::Complete(_) => { + self.proxy_task_state.tasks.pop_front(); + continue; + } + BodyMode::ContentLength(total, written) if written >= total => { + self.proxy_task_state.tasks.pop_front(); + continue; + } + BodyMode::ToSelect => { + self.proxy_task_state.tasks.pop_front(); + self.proxy_task_state.clear_stream_end(); + return Error::e_explain( + InternalError, + "subrequest body proxy task before header is sent", + ); + } + _ => {} + } + } + + // `reserve()` is the only cancellation point; the queued task is + // popped only after a permit is acquired. + let tx_ref = self + .tx + .as_ref() + .ok_or_else(|| Error::explain(InternalError, "subrequest tx already shut down"))?; + let permit = match self.write_timeout { + Some(t) => match timeout(t, tx_ref.reserve()).await { + Ok(res) => res.or_err(WriteError, "subrequest channel closed")?, + Err(_) => { + return Error::e_explain( + WriteTimedout, + format!("reserving subrequest channel slot, timeout: {t:?}"), + ); + } + }, + None => tx_ref + .reserve() + .await + .or_err(WriteError, "subrequest channel closed")?, + }; + + // From here until `permit.send`, no `.await`; dispatch is atomic. 
+ let task = self + .proxy_task_state + .tasks + .pop_front() + .expect("queue non-empty"); + + match task { + HttpTask::Header(header, hdr_end) => { + if hdr_end { + self.proxy_task_state.require_finish(); + } + + let upgrade_ok = if header.status == 101 || !header.status.is_informational() { + let outcome = self.v1_inner.is_upgrade(&header); + let mode = body_mode_for_header( + &header, + self.v1_inner.get_method(), + outcome == Some(true), + ); + (outcome, mode) + } else { + (None, None) + }; + + permit.send(HttpTask::Header(header.clone(), false)); + + if let Some(mode) = upgrade_ok.1 { + apply_body_mode(&mut self.body_writer, mode); + } + self.response_written = Some(*header); + if let Some(true) = upgrade_ok.0 { + debug!("ok upgrade handshake (subrequest, proxy task)"); + self.upgraded = true; + if self.body_reader.need_init() { + self.init_body_reader(); + } else { + self.body_reader.convert_to_close_delimited(); + } + } else if upgrade_ok.0 == Some(false) { + debug!("bad upgrade handshake! (subrequest, proxy task)"); + } + } + + HttpTask::Body(data, end) => { + if end { + self.proxy_task_state.require_finish(); + } + dispatch_body_inline( + &mut self.body_writer, + &mut self.body_bytes_sent, + self.upgraded, + data, + /* upgraded_task = */ false, + permit, + )?; + } + HttpTask::UpgradedBody(data, end) => { + if end { + self.proxy_task_state.require_finish(); + } + dispatch_body_inline( + &mut self.body_writer, + &mut self.body_bytes_sent, + self.upgraded, + data, + /* upgraded_task = */ true, + permit, + )?; + } + + HttpTask::Trailer(trailers) => { + permit.send(HttpTask::Trailer(trailers)); + self.proxy_task_state.require_finish(); + } + + HttpTask::Done | HttpTask::Failed(_) => { + unreachable!("Done/Failed are handled above without reserving a permit") + } + } + } + + // Match `response_duplex_vec`: finish whenever any task signalled EOS. 
+ if self.proxy_task_state.finish_required() || self.body_writer.finished() { + self.finish_proxy_task().await?; + self.proxy_task_state.mark_finished(); + return Ok(true); + } + + Ok(self.body_writer.finished()) + } + + async fn finish_proxy_task(&mut self) -> Result<()> { + if matches!( + &self.body_writer.body_mode, + BodyMode::Complete(_) | BodyMode::ToSelect + ) { + self.maybe_force_close_body_reader(); + self.proxy_task_state.finish_in_progress = false; + return Ok(()); + } + + if let BodyMode::ContentLength(total, written) = self.body_writer.body_mode { + if written < total { + self.body_writer.body_mode = BodyMode::Complete(written); + self.proxy_task_state.finish_in_progress = false; + self.proxy_task_state.clear_stream_end(); + return Error::e_explain( + PREMATURE_BODY_END, + format!( + "Content-length: {total} bytes written: {written} (subrequest, proxy task)" + ), + ); + } + } + + self.proxy_task_state.finish_in_progress = true; + self.dispatch_finish().await?; + self.proxy_task_state.finish_in_progress = false; + Ok(()) + } + + /// Dispatch the final `HttpTask::Done`, mirroring `body_writer::finish`. + async fn dispatch_finish(&mut self) -> Result<()> { + // Reserve cancel-safely, then synchronously update body_mode and send. 
+ let tx_ref = self + .tx + .as_ref() + .ok_or_else(|| Error::explain(InternalError, "subrequest tx already shut down"))?; + let permit = match self.write_timeout { + Some(t) => match timeout(t, tx_ref.reserve()).await { + Ok(res) => res.or_err(WriteError, "subrequest channel closed")?, + Err(_) => { + return Error::e_explain( + WriteTimedout, + format!("reserving subrequest channel slot for finish, timeout: {t:?}"), + ); + } + }, + None => tx_ref + .reserve() + .await + .or_err(WriteError, "subrequest channel closed")?, + }; + + match self.body_writer.body_mode { + BodyMode::ContentLength(_total, written) => { + self.body_writer.body_mode = BodyMode::Complete(written); + permit.send(HttpTask::Done); + } + BodyMode::UntilClose(written) => { + self.body_writer.body_mode = BodyMode::Complete(written); + permit.send(HttpTask::Done); + } + BodyMode::Complete(_) => { + unreachable!("no-op body modes are handled before reserve") + } + BodyMode::ToSelect => { + unreachable!("no-op body modes are handled before reserve") + } + } + self.maybe_force_close_body_reader(); + Ok(()) + } +} + +fn body_mode_for_header( + header: &ResponseHeader, + method: Option<&Method>, + is_upgrade_ok: bool, +) -> Option { + use http::StatusCode; + if header.status.is_informational() && header.status != StatusCode::SWITCHING_PROTOCOLS { + return None; + } + if matches!( + header.status, + StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED + ) || method == Some(&Method::HEAD) + { + return Some(BodyMode::ContentLength(0, 0)); + } + if is_upgrade_ok || is_chunked_encoding_from_headers(&header.headers) { + Some(BodyMode::UntilClose(0)) + } else { + let content_length = + header_value_content_length(header.headers.get(http::header::CONTENT_LENGTH)); + match content_length { + Some(length) => Some(BodyMode::ContentLength(length, 0)), + None => Some(BodyMode::UntilClose(0)), + } + } +} + +fn apply_body_mode(body_writer: &mut BodyWriter, mode: BodyMode) { + match mode { + BodyMode::ContentLength(total, 
0) => body_writer.init_content_length(total), + BodyMode::UntilClose(0) => body_writer.init_close_delimited(), + _ => body_writer.body_mode = mode, + } +} + +/// Body dispatch variant that avoids borrowing the whole session while a +/// channel permit borrows `self.tx`. +fn dispatch_body_inline( + body_writer: &mut BodyWriter, + body_bytes_sent: &mut usize, + upgraded: bool, + data: Option, + upgraded_task: bool, + permit: mpsc::Permit<'_, HttpTask>, +) -> Result<()> { + if upgraded_task != upgraded { + if upgraded_task { + panic!("Unexpected UpgradedBody task received on un-upgraded downstream session (subrequest, proxy task)"); + } else { + panic!("Unexpected Body task received on upgraded downstream session (subrequest, proxy task)"); + } + } + + let Some(d) = data else { + drop(permit); + return Ok(()); + }; + if d.is_empty() { + drop(permit); + return Ok(()); + } + + let (to_count, next_mode) = match &body_writer.body_mode { + BodyMode::ContentLength(total, written) => { + if written >= total { + drop(permit); + return Ok(()); + } + let remaining = *total - *written; + let to_write = if remaining < d.len() { + warn!("Trying to write data over content-length (subrequest, proxy task): {total}"); + remaining + } else { + d.len() + }; + ( + to_write, + BodyMode::ContentLength(*total, *written + to_write), + ) + } + BodyMode::UntilClose(written) => (d.len(), BodyMode::UntilClose(*written + d.len())), + BodyMode::Complete(_) => { + drop(permit); + return Ok(()); + } + BodyMode::ToSelect => { + drop(permit); + return Error::e_explain( + InternalError, + "subrequest body proxy task before header is sent", + ); + } + }; + + let to_send = if to_count < d.len() { + d.slice(..to_count) + } else { + d + }; + permit.send(HttpTask::Body(Some(to_send), false)); + body_writer.body_mode = next_mode; + *body_bytes_sent += to_count; + Ok(()) } #[cfg(test)] @@ -817,25 +1257,49 @@ mod tests_stream { let _ = env_logger::builder().is_test(true).try_init(); } + fn test_header(status: 
StatusCode) -> ResponseHeader { + ResponseHeader::build(status, None) + .expect("test status code should build a response header") + } + + fn recv_task(rx: &mut mpsc::Receiver) -> HttpTask { + rx.try_recv() + .expect("expected subrequest output task to be queued") + } + async fn session_from_input(input: &[u8]) -> (HttpSession, SubrequestHandle) { let mock_io = Builder::new().read(input).build(); let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io)); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); let (mut http_stream, handle) = HttpSession::new_from_session(&http_stream); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); (http_stream, handle) } - async fn build_upgrade_req(upgrade: &str, conn: &str) -> (HttpSession, SubrequestHandle) { + pub(super) async fn build_upgrade_req( + upgrade: &str, + conn: &str, + ) -> (HttpSession, SubrequestHandle) { let input = format!("GET / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: {upgrade}\r\nConnection: {conn}\r\n\r\n"); session_from_input(input.as_bytes()).await } - async fn build_req() -> (HttpSession, SubrequestHandle) { + pub(super) async fn build_req() -> (HttpSession, SubrequestHandle) { let input = "GET / HTTP/1.1\r\nHost: pingora.org\r\n\r\n".to_string(); session_from_input(input.as_bytes()).await } + pub(super) async fn build_head_req() -> (HttpSession, SubrequestHandle) { + let input = "HEAD / HTTP/1.1\r\nHost: pingora.org\r\n\r\n".to_string(); + session_from_input(input.as_bytes()).await + } + #[tokio::test] async fn read_basic() { init_log(); @@ -880,12 +1344,12 @@ mod tests_stream { async fn read_upgrade_req_with_1xx_response() { let (mut http_stream, _handle) = build_upgrade_req("websocket", "upgrade").await; assert!(http_stream.is_upgrade_req()); - let mut response = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap(); 
+ let mut response = test_header(StatusCode::CONTINUE); response.set_version(http::Version::HTTP_11); http_stream .write_response_header(Box::new(response)) .await - .unwrap(); + .expect("test operation should succeed"); // 100 won't affect body state assert!(http_stream.is_body_done()); } @@ -893,13 +1357,15 @@ mod tests_stream { #[tokio::test] async fn write() { let (mut http_stream, mut handle) = build_req().await; - let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); - new_response.append_header("Foo", "Bar").unwrap(); + let mut new_response = test_header(StatusCode::OK); + new_response + .append_header("Foo", "Bar") + .expect("test operation should succeed"); http_stream .write_response_header_ref(&new_response) .await - .unwrap(); - match handle.rx.try_recv().unwrap() { + .expect("test operation should succeed"); + match recv_task(&mut handle.rx) { HttpTask::Header(header, end) => { assert_eq!(header.status, StatusCode::OK); assert_eq!(header.headers["foo"], "Bar"); @@ -912,12 +1378,12 @@ mod tests_stream { #[tokio::test] async fn write_informational() { let (mut http_stream, mut handle) = build_req().await; - let response_100 = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap(); + let response_100 = test_header(StatusCode::CONTINUE); http_stream .write_response_header_ref(&response_100) .await - .unwrap(); - match handle.rx.try_recv().unwrap() { + .expect("test operation should succeed"); + match recv_task(&mut handle.rx) { HttpTask::Header(header, end) => { assert_eq!(header.status, StatusCode::CONTINUE); assert!(!end); @@ -925,12 +1391,12 @@ mod tests_stream { t => panic!("unexpected task {t:?}"), } - let response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap(); + let response_200 = test_header(StatusCode::OK); http_stream .write_response_header_ref(&response_200) .await - .unwrap(); - match handle.rx.try_recv().unwrap() { + .expect("test operation should succeed"); + match recv_task(&mut handle.rx) { 
HttpTask::Header(header, end) => { assert_eq!(header.status, StatusCode::OK); assert!(!end); @@ -942,15 +1408,16 @@ mod tests_stream { #[tokio::test] async fn write_101_switching_protocol() { let (mut http_stream, mut handle) = build_upgrade_req("WebSocket", "Upgrade").await; - let mut response_101 = - ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); - response_101.append_header("Foo", "Bar").unwrap(); + let mut response_101 = test_header(StatusCode::SWITCHING_PROTOCOLS); + response_101 + .append_header("Foo", "Bar") + .expect("test operation should succeed"); http_stream .write_response_header_ref(&response_101) .await - .unwrap(); + .expect("test operation should succeed"); - match handle.rx.try_recv().unwrap() { + match recv_task(&mut handle.rx) { HttpTask::Header(header, end) => { assert_eq!(header.status, StatusCode::SWITCHING_PROTOCOLS); assert!(!end); @@ -964,16 +1431,17 @@ mod tests_stream { .write_body(wire_body.clone()) .await .unwrap() - .unwrap(); + .expect("test operation should succeed"); assert_eq!(wire_body.len(), n); // this write should be ignored - let response_502 = ResponseHeader::build(StatusCode::BAD_GATEWAY, None).unwrap(); + let response_502 = ResponseHeader::build(StatusCode::BAD_GATEWAY, None) + .expect("test operation should succeed"); http_stream .write_response_header_ref(&response_502) .await - .unwrap(); + .expect("test operation should succeed"); - match handle.rx.try_recv().unwrap() { + match recv_task(&mut handle.rx) { HttpTask::Body(body, _end) => { assert_eq!(body.unwrap().len(), n); } @@ -989,12 +1457,14 @@ mod tests_stream { async fn write_body_cl() { let (mut http_stream, _handle) = build_req().await; let wire_body = Bytes::from(&b"a"[..]); - let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); - new_response.append_header("Content-Length", "1").unwrap(); + let mut new_response = test_header(StatusCode::OK); + new_response + .append_header("Content-Length", "1") + .expect("test 
operation should succeed"); http_stream .write_response_header_ref(&new_response) .await - .unwrap(); + .expect("test operation should succeed"); assert_eq!( http_stream.body_writer.body_mode, BodyMode::ContentLength(1, 0) @@ -1003,29 +1473,37 @@ mod tests_stream { .write_body(wire_body.clone()) .await .unwrap() - .unwrap(); + .expect("test operation should succeed"); assert_eq!(wire_body.len(), n); - let n = http_stream.finish().await.unwrap().unwrap(); + let n = http_stream + .finish() + .await + .expect("test async operation should succeed") + .expect("test operation should succeed"); assert_eq!(wire_body.len(), n); } #[tokio::test] async fn write_body_until_close() { let (mut http_stream, _handle) = build_req().await; - let new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); + let new_response = test_header(StatusCode::OK); http_stream .write_response_header_ref(&new_response) .await - .unwrap(); + .expect("test operation should succeed"); assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0)); let wire_body = Bytes::from(&b"PAYLOAD"[..]); let n = http_stream .write_body(wire_body.clone()) .await .unwrap() - .unwrap(); + .expect("test operation should succeed"); assert_eq!(wire_body.len(), n); - let n = http_stream.finish().await.unwrap().unwrap(); + let n = http_stream + .finish() + .await + .expect("test async operation should succeed") + .expect("test operation should succeed"); assert_eq!(wire_body.len(), n); } @@ -1037,17 +1515,27 @@ mod tests_stream { let input3 = b"abc"; let mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io)); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); let (mut http_stream, handle) = HttpSession::new_from_session(&http_stream); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + 
.expect("test async operation should succeed"); handle .tx .send(HttpTask::Body(Some(Bytes::from(&input3[..])), false)) .await - .unwrap(); + .expect("test operation should succeed"); assert_eq!(http_stream.get_path(), &b"/a?q=b%20c"[..]); - let res = http_stream.read_body().await.unwrap().unwrap(); + let res = http_stream + .read_body() + .await + .expect("test async operation should succeed") + .expect("test operation should succeed"); assert_eq!(res, &input3[..]); assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3)); } @@ -1056,22 +1544,27 @@ mod tests_stream { async fn test_write_body_write_timeout() { let (mut http_stream, _handle) = build_req().await; http_stream.write_timeout = Some(Duration::from_millis(100)); - let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); - new_response.append_header("Content-Length", "10").unwrap(); + let mut new_response = test_header(StatusCode::OK); + new_response + .append_header("Content-Length", "10") + .expect("test operation should succeed"); http_stream .write_response_header_ref(&new_response) .await - .unwrap(); + .expect("test operation should succeed"); let body_write_buf = Bytes::from(&b"abc"[..]); http_stream .write_body(body_write_buf.clone()) .await - .unwrap(); + .expect("test operation should succeed"); http_stream .write_body(body_write_buf.clone()) .await - .unwrap(); - http_stream.write_body(body_write_buf).await.unwrap(); + .expect("test operation should succeed"); + http_stream + .write_body(body_write_buf) + .await + .expect("test async operation should succeed"); // channel full let last_body = Bytes::from(&b"a"[..]); let res = http_stream.write_body(last_body).await; @@ -1081,8 +1574,11 @@ mod tests_stream { #[tokio::test] async fn test_write_continue_resp() { let (mut http_stream, mut handle) = build_req().await; - http_stream.write_continue_response().await.unwrap(); - match handle.rx.try_recv().unwrap() { + http_stream + .write_continue_response() + .await + 
.expect("test async operation should succeed"); + match recv_task(&mut handle.rx) { HttpTask::Header(header, end) => { assert_eq!(header.status, StatusCode::CONTINUE); assert!(!end); @@ -1095,7 +1591,10 @@ mod tests_stream { let mock_io = Builder::new().read(input).build(); let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io)); // Read the request in v1 inner session to set up headers properly - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); let (http_stream, handle) = HttpSession::new_from_session(&http_stream); (http_stream, handle) } @@ -1157,9 +1656,15 @@ mod tests_stream { async fn build_upgrade_req_with_body(header: &[u8]) -> (HttpSession, SubrequestHandle) { let mock_io = Builder::new().read(header).build(); let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io)); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); let (mut http_stream, handle) = HttpSession::new_from_session(&http_stream); - http_stream.read_request().await.unwrap(); + http_stream + .read_request() + .await + .expect("test async operation should succeed"); (http_stream, handle) } @@ -1179,10 +1684,14 @@ mod tests_stream { .tx .send(HttpTask::Body(Some(Bytes::from(POST_BODY_DATA)), true)) .await - .unwrap(); + .expect("test operation should succeed"); let mut buf = vec![]; - while let Some(b) = http_stream.read_body_bytes().await.unwrap() { + while let Some(b) = http_stream + .read_body_bytes() + .await + .expect("test async operation should succeed") + { buf.put_slice(&b); } assert_eq!(buf, POST_BODY_DATA); @@ -1191,12 +1700,12 @@ mod tests_stream { assert!(http_stream.is_body_done()); - let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); + let mut response = test_header(StatusCode::SWITCHING_PROTOCOLS); response.set_version(http::Version::HTTP_11); 
http_stream .write_response_header(Box::new(response)) .await - .unwrap(); + .expect("test operation should succeed"); // body reader type switches assert!(!http_stream.is_body_done()); @@ -1206,15 +1715,1128 @@ mod tests_stream { .tx .send(HttpTask::Body(Some(Bytes::from(&ws_data[..])), false)) .await - .unwrap(); + .expect("test operation should succeed"); - let buf = http_stream.read_body_bytes().await.unwrap().unwrap(); + let buf = http_stream + .read_body_bytes() + .await + .expect("test async operation should succeed") + .expect("test operation should succeed"); assert_eq!(buf, ws_data.as_slice()); assert!(!http_stream.is_body_done()); // EOF ends body drop(handle.tx); - assert!(http_stream.read_body_bytes().await.unwrap().is_none()); + assert!(http_stream + .read_body_bytes() + .await + .expect("test async operation should succeed") + .is_none()); assert!(http_stream.is_body_done()); } } + +#[cfg(test)] +mod test_proxy_tasks { + //! Cancel-safe proxy task API tests for the subrequest server session. + + use super::tests_stream::{build_head_req, build_req, build_upgrade_req}; + use super::*; + use http::StatusCode; + + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + fn test_header(status: StatusCode) -> ResponseHeader { + ResponseHeader::build(status, None) + .expect("test status code should build a response header") + } + + fn recv_task(rx: &mut mpsc::Receiver) -> HttpTask { + rx.try_recv() + .expect("expected subrequest output task to be queued") + } + + fn assert_rx_empty(rx: &mut mpsc::Receiver) { + assert!(matches!( + rx.try_recv(), + Err(mpsc::error::TryRecvError::Empty) + )); + } + + /// Dropping a blocked `send` does not deliver its value. + #[tokio::test(start_paused = true)] + async fn test_tokio_mpsc_send_cancel_drops_value() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(1) + .await + .expect("test async operation should succeed"); + let send_fut = tx.send(2); + tokio::pin!(send_fut); + tokio::select! 
{ + biased; + _ = tokio::time::sleep(Duration::from_millis(10)) => {} + _ = &mut send_fut => panic!("expected the timer to win"), + }; + assert_eq!(rx.recv().await, Some(1)); + assert_eq!(rx.try_recv(), Err(mpsc::error::TryRecvError::Empty)); + } + + /// Dropping a blocked `reserve` does not consume capacity. + #[tokio::test(start_paused = true)] + async fn test_tokio_mpsc_reserve_cancel_releases_slot() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(1) + .await + .expect("test async operation should succeed"); + let reserve_fut = tx.reserve(); + tokio::pin!(reserve_fut); + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(10)) => {} + _ = &mut reserve_fut => panic!("expected the timer to win"), + }; + assert_eq!(rx.recv().await, Some(1)); + assert_eq!(rx.try_recv(), Err(mpsc::error::TryRecvError::Empty)); + } + + #[tokio::test] + async fn test_send_proxy_task_and_write() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + assert!(session.proxy_tasks_enabled()); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "5") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + assert!(!session.has_pending_proxy_tasks()); + assert_eq!(session.body_bytes_sent(), 5); + + match recv_task(&mut handle.rx) { + HttpTask::Header(h, false) => assert_eq!(h.status, StatusCode::OK), + t => panic!("expected Header, got {t:?}"), + } + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"hello"), + t => panic!("expected Body, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + 
async fn test_informational_head_does_not_init_body_writer() { + init_log(); + + let (mut regular, mut regular_handle) = build_head_req().await; + regular + .write_response_header(Box::new(test_header(StatusCode::CONTINUE))) + .await + .expect("regular informational header write should succeed"); + assert_eq!(regular.body_writer.body_mode, BodyMode::ToSelect); + assert!(matches!( + recv_task(&mut regular_handle.rx), + HttpTask::Header(..) + )); + + let (mut proxy, mut proxy_handle) = build_head_req().await; + proxy.set_proxy_tasks_enabled(true); + proxy.send_proxy_task(HttpTask::Header( + Box::new(test_header(StatusCode::CONTINUE)), + false, + )); + assert!(!proxy + .write_proxy_tasks() + .await + .expect("proxy task write should succeed")); + assert_eq!(proxy.body_writer.body_mode, BodyMode::ToSelect); + assert!(matches!( + recv_task(&mut proxy_handle.rx), + HttpTask::Header(..) + )); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_with_timeout() { + init_log(); + // Do not drain `handle.rx` before the first write; the 5th response task blocks on capacity. 
+ let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + session.set_write_timeout(Some(Duration::from_millis(50))); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + for i in 0..5 { + session.send_proxy_task(HttpTask::Body( + Some(Bytes::from(format!("body-{i}"))), + i == 4, + )); + } + + let err = session + .write_proxy_tasks() + .await + .expect_err("full subrequest output channel should time out"); + assert_eq!(err.etype(), &WriteTimedout); + assert!(session.has_pending_proxy_tasks()); + + let mut delivered = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + assert_eq!(delivered.len(), 4); + assert!(matches!(delivered[0], HttpTask::Header(..))); + + session.set_write_timeout(None); + let end = session + .write_proxy_tasks() + .await + .expect("retry after freeing channel capacity should complete"); + assert!(end); + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + + assert_eq!(delivered.len(), 7); + assert!(matches!(delivered.last(), Some(HttpTask::Done))); + let body_count = delivered + .iter() + .filter(|t| matches!(t, HttpTask::Body(..))) + .count(); + assert_eq!(body_count, 5); + } + + #[tokio::test] + async fn test_proxy_task_channel_closed_errors() { + init_log(); + let (mut session, handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + drop(handle.rx); + + session.send_proxy_task(HttpTask::Header( + Box::new(test_header(StatusCode::OK)), + false, + )); + let err = session + .write_proxy_tasks() + .await + .expect_err("closed subrequest output channel should error"); + assert_eq!(err.etype(), &WriteError); + assert!(session.has_pending_proxy_tasks()); + } + + /// Repeatedly cancel while blocked on channel capacity, then verify the + /// receiver sees each queued task exactly once and in order. 
+ #[tokio::test(start_paused = true)] + async fn test_proxy_task_cancel_safety() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "5") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + for i in 0..4 { + session.send_proxy_task(HttpTask::Body( + Some(Bytes::from(vec![b'A' + i as u8; 1])), + false, + )); + } + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("E")), true)); + + let mut cancel_count = 0; + let mut delivered: Vec = Vec::new(); + loop { + if !session.has_pending_proxy_tasks() { + break; + } + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => { + cancel_count += 1; + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + } + result = session.write_proxy_tasks() => { + result.expect("test operation should succeed"); + } + } + } + + assert!( + cancel_count >= 1, + "expected at least one cancellation during cancel-safe write, got {cancel_count}" + ); + + assert_eq!(session.proxy_task_state.tasks.len(), 0); + + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + + assert!(matches!(delivered[0], HttpTask::Header(_, false))); + let mut body_bytes = Vec::new(); + let mut saw_done = false; + for task in &delivered[1..] 
{ + match task { + HttpTask::Body(Some(b), false) => body_bytes.extend_from_slice(b), + HttpTask::Done => { + assert!(!saw_done, "Done delivered more than once"); + saw_done = true; + } + t => panic!("unexpected task in delivery: {t:?}"), + } + } + assert!(saw_done, "expected Done to be delivered"); + assert_eq!( + body_bytes, b"ABCDE", + "body chunks must arrive exactly once, in order" + ); + + assert_eq!(session.body_bytes_sent(), 5); + } + + /// `was_upgraded()` must remain false if the 101 send is cancelled before + /// it reaches the subrequest channel. + #[tokio::test(start_paused = true)] + async fn test_proxy_task_upgrade_consistency() { + init_log(); + let (mut session, mut handle) = build_upgrade_req("websocket", "Upgrade").await; + assert!(session.is_upgrade_req()); + session.set_proxy_tasks_enabled(true); + + // Four 1xx headers fill the upstream channel; the 101 then blocks. + for _ in 0..4 { + session.send_proxy_task(HttpTask::Header( + Box::new(test_header(StatusCode::CONTINUE)), + false, + )); + } + let mut h101 = test_header(StatusCode::SWITCHING_PROTOCOLS); + h101.set_version(http::Version::HTTP_11); + h101.insert_header("upgrade", "websocket") + .expect("test operation should succeed"); + h101.insert_header("connection", "Upgrade") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(h101), false)); + + tokio::select! 
{ + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_proxy_tasks() => panic!("expected reserve to be cancelled"), + }; + + assert!( + !session.was_upgraded(), + "was_upgraded must remain false until the 101 send actually completes" + ); + + for _ in 0..4 { + recv_task(&mut handle.rx); + } + session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!( + session.was_upgraded(), + "after the 101 send completes, was_upgraded must be true" + ); + } + + /// Same upgrade consistency check for the regular `write_response_header` + /// path, which also awaits on the subrequest output channel. + #[tokio::test(start_paused = true)] + async fn test_write_response_header_upgrade_cancel_consistency() { + init_log(); + let (mut session, mut handle) = build_upgrade_req("websocket", "Upgrade").await; + + for _ in 0..4 { + session + .write_response_header(Box::new(test_header(StatusCode::CONTINUE))) + .await + .expect("test operation should succeed"); + } + + let mut h101 = test_header(StatusCode::SWITCHING_PROTOCOLS); + h101.set_version(http::Version::HTTP_11); + h101.insert_header("upgrade", "websocket") + .expect("test operation should succeed"); + h101.insert_header("connection", "Upgrade") + .expect("test operation should succeed"); + + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_response_header(Box::new(h101.clone())) => { + panic!("expected header send to be cancelled") + } + }; + assert!(!session.was_upgraded()); + assert_eq!(session.body_writer.body_mode, BodyMode::ToSelect); + + for _ in 0..4 { + recv_task(&mut handle.rx); + } + session + .write_response_header(Box::new(h101)) + .await + .expect("test async operation should succeed"); + assert!(session.was_upgraded()); + } + + /// Trailers are dispatched correctly through `write_proxy_tasks`. + /// Matching regular `response_duplex_vec`, a final `Done` follows Trailer. 
+ #[tokio::test] + async fn test_proxy_task_trailers() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("hi")), false)); + let mut trailers = http::HeaderMap::new(); + trailers.insert("x-final", http::HeaderValue::from_static("done")); + session.send_proxy_task(HttpTask::Trailer(Some(Box::new(trailers)))); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Body(..))); + match recv_task(&mut handle.rx) { + HttpTask::Trailer(Some(t)) => { + assert_eq!( + t.get("x-final").expect("test trailer should be present"), + "done" + ); + } + t => panic!("expected Trailer, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_trailer_before_content_length_complete_errors() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "5") + .expect("test content-length header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + let mut trailers = HeaderMap::new(); + trailers.insert("x-final", http::HeaderValue::from_static("done")); + session.send_proxy_task(HttpTask::Trailer(Some(Box::new(trailers)))); + + let err = session + .write_proxy_tasks() + .await + .expect_err("trailers before content-length body completion should error"); + assert_eq!(err.etype(), &PREMATURE_BODY_END); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + 
assert!(matches!(recv_task(&mut handle.rx), HttpTask::Trailer(..))); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_trailer_cancel_safety() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + for _ in 0..4 { + session.send_proxy_task(HttpTask::Header( + Box::new(test_header(StatusCode::CONTINUE)), + false, + )); + } + let mut trailers = HeaderMap::new(); + trailers.insert("x-final", http::HeaderValue::from_static("done")); + session.send_proxy_task(HttpTask::Trailer(Some(Box::new(trailers)))); + + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_proxy_tasks() => panic!("expected trailer reserve to be cancelled"), + }; + + let mut prefix = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + prefix.push(task); + } + assert_eq!(prefix.len(), 4); + assert!(prefix.iter().all(|t| matches!(t, HttpTask::Header(..)))); + let end = session + .write_proxy_tasks() + .await + .expect("resume after trailer cancellation should complete"); + assert!(end); + + match recv_task(&mut handle.rx) { + HttpTask::Trailer(Some(t)) => { + assert_eq!(t.get("x-final").expect("trailer present"), "done") + } + t => panic!("expected Trailer after resume, got {t:?}"), + } + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_trailer_cancel_safety_after_body() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + for _ in 0..3 { + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("x")), false)); + } + let mut trailers = HeaderMap::new(); + trailers.insert("x-final", http::HeaderValue::from_static("done")); + session.send_proxy_task(HttpTask::Trailer(Some(Box::new(trailers)))); + + 
tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_proxy_tasks() => panic!("expected trailer reserve to be cancelled"), + }; + + let mut prefix = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + prefix.push(task); + } + assert_eq!(prefix.len(), 4); + assert!(matches!(prefix[0], HttpTask::Header(..))); + assert!(prefix[1..].iter().all(|t| matches!(t, HttpTask::Body(..)))); + let end = session + .write_proxy_tasks() + .await + .expect("resume after body trailer cancellation should complete"); + assert!(end); + + match recv_task(&mut handle.rx) { + HttpTask::Trailer(Some(t)) => { + assert_eq!(t.get("x-final").expect("trailer present"), "done") + } + t => panic!("expected Trailer after resume, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + /// `body_bytes_sent` is only incremented after the synchronous + /// `Permit::send`, not on a cancelled `reserve().await`. + #[tokio::test(start_paused = true)] + async fn test_proxy_task_body_counter_no_double_count() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "12") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("AAAA")), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("BBBB")), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("CCCC")), true)); + + let mut received_body_bytes = Vec::new(); + loop { + if !session.has_pending_proxy_tasks() { + break; + } + tokio::select! 
{ + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => { + while let Ok(task) = handle.rx.try_recv() { + if let HttpTask::Body(Some(b), _) = &task { + received_body_bytes.extend_from_slice(b); + } + } + assert!( + session.body_bytes_sent() <= received_body_bytes.len(), + "body_bytes_sent ({}) must not exceed bytes actually delivered ({})", + session.body_bytes_sent(), + received_body_bytes.len(), + ); + } + result = session.write_proxy_tasks() => { + result.expect("test operation should succeed"); + } + } + } + + while let Ok(task) = handle.rx.try_recv() { + if let HttpTask::Body(Some(b), _) = &task { + received_body_bytes.extend_from_slice(b); + } + } + + assert_eq!(session.body_bytes_sent(), 12); + assert_eq!(&received_body_bytes[..], b"AAAABBBBCCCC"); + } + + /// Cancelling while reserving capacity for the final Done must leave the + /// finish operation resumable. + #[tokio::test(start_paused = true)] + async fn test_proxy_task_finish_cancel_safety() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("a")), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("b")), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("c")), true)); + + // Header + three bodies fill the 4-slot channel. The final Done + // reserve blocks and is cancelled. + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_proxy_tasks() => panic!("expected finish reserve to be cancelled"), + }; + assert!(session.proxy_task_state.finish_in_progress); + + let first = recv_task(&mut handle.rx); + assert!(matches!(first, HttpTask::Header(..))); + + // Only one slot is available. Resuming must emit exactly one Done and + // return without trying to reserve a second slot. 
+ let end = tokio::time::timeout(Duration::from_millis(5), session.write_proxy_tasks()) + .await + .expect("resume should not need a second channel slot") + .expect("test operation should succeed"); + assert!(end); + + let mut delivered = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + assert_eq!(delivered.len(), 4); + assert!(matches!(delivered.last(), Some(HttpTask::Done))); + assert_eq!( + delivered + .iter() + .filter(|t| matches!(t, HttpTask::Done)) + .count(), + 1 + ); + } + + #[tokio::test] + async fn test_proxy_task_done_only_noops_without_channel() { + init_log(); + let (mut session, _handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + session.send_proxy_task(HttpTask::Done); + session.shutdown(); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + } + + #[tokio::test] + async fn test_proxy_task_done_only_noops_with_live_channel() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + session.send_proxy_task(HttpTask::Done); + + let end = session + .write_proxy_tasks() + .await + .expect("Done-only proxy task should complete without channel output"); + assert!(end); + assert!(!session.has_pending_proxy_tasks()); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_header_only_end() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::NO_CONTENT); + session.send_proxy_task(HttpTask::Header(Box::new(header), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn 
test_proxy_task_head_response_drops_body() { + init_log(); + let (mut session, mut handle) = build_head_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "10") + .expect("test content-length header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("not-sent")), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("HEAD proxy task response should complete"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + assert_eq!(session.body_bytes_sent(), 0); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_duplicate_final_header_does_not_reserve() { + init_log(); + let (mut session, _handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + + for _ in 0..3 { + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("x")), false)); + } + let duplicate = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(duplicate), false)); + + tokio::time::timeout(Duration::from_millis(5), session.write_proxy_tasks()) + .await + .expect("duplicate final header should be dropped without reserving") + .expect("test operation should succeed"); + assert!(!session.has_pending_proxy_tasks()); + } + + #[tokio::test] + async fn test_proxy_task_duplicate_final_header_preserves_end() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), 
false)); + session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + recv_task(&mut handle.rx); + + let duplicate = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(duplicate), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_failed_propagates_without_sending_later_tasks() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + session.send_proxy_task(HttpTask::Failed(Error::explain(InternalError, "boom"))); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("late")), true)); + + let err = session + .write_proxy_tasks() + .await + .expect_err("Failed proxy task should propagate error"); + assert_eq!(err.etype(), &InternalError); + assert!(session.has_pending_proxy_tasks()); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_failed_clears_sticky_eos() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(None, true)); + session.send_proxy_task(HttpTask::Failed(Error::explain(InternalError, "boom"))); + + let err = session + .write_proxy_tasks() + .await + .expect_err("Failed after EOS should still propagate error"); + assert_eq!(err.etype(), &InternalError); + assert!(!session.has_pending_proxy_tasks()); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + assert_rx_empty(&mut handle.rx); + + let end = session + .write_proxy_tasks() + .await + .expect("retry after Failed should not emit stale Done"); + assert!(!end); + assert_rx_empty(&mut handle.rx); + } 
+ + #[tokio::test] + async fn test_proxy_task_body_before_header_errors() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("body")), true)); + + let err = session + .write_proxy_tasks() + .await + .expect_err("body before response header should be rejected"); + assert_eq!(err.etype(), &InternalError); + assert!(!session.has_pending_proxy_tasks()); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_none_body_does_not_reserve() { + init_log(); + let (mut session, _handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + for _ in 0..3 { + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("x")), false)); + } + session.send_proxy_task(HttpTask::Body(None, false)); + + tokio::time::timeout(Duration::from_millis(5), session.write_proxy_tasks()) + .await + .expect("Body(None) should not reserve channel capacity") + .expect("test operation should succeed"); + assert!(!session.has_pending_proxy_tasks()); + } + + #[tokio::test(start_paused = true)] + async fn test_proxy_task_no_data_eos_survives_cancellation() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + for _ in 0..3 { + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("x")), false)); + } + session.send_proxy_task(HttpTask::Body(None, true)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::new()), true)); + // This later body blocks on the full channel after the no-data EOS + // task has been consumed. The sticky EOS flag must survive that cancel. 
+ session.send_proxy_task(HttpTask::Body(Some(Bytes::from("y")), false)); + + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(5)) => {} + _ = session.write_proxy_tasks() => panic!("expected reserve after no-data EOS to be cancelled"), + }; + + let mut prefix = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + prefix.push(task); + } + assert_eq!(prefix.len(), 4); + assert!(matches!(prefix[0], HttpTask::Header(..))); + assert!(prefix[1..].iter().all(|t| matches!(t, HttpTask::Body(..)))); + let end = session + .write_proxy_tasks() + .await + .expect("resume after no-data EOS cancellation should complete"); + assert!(end); + + let mut delivered = Vec::new(); + while let Ok(task) = handle.rx.try_recv() { + delivered.push(task); + } + assert_eq!(delivered.len(), 2); + match &delivered[0] { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"y"), + t => panic!("expected Body(y) after resume, got {t:?}"), + } + assert!(matches!(delivered[1], HttpTask::Done)); + assert_eq!(session.body_bytes_sent(), 4); + } + + #[tokio::test] + async fn test_proxy_task_content_length_overrun_truncates() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "3") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("abcdef")), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"abc"), + t => panic!("expected truncated Body, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + 
assert_eq!(session.body_bytes_sent(), 3); + } + + #[tokio::test] + async fn test_proxy_task_exact_content_length_without_end_sends_done() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "3") + .expect("test content-length header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("abc")), false)); + + let end = session + .write_proxy_tasks() + .await + .expect("exact content-length proxy task response should complete"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"abc"), + t => panic!("expected exact content-length Body, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + async fn test_proxy_task_late_tasks_after_finished_are_dropped() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "3") + .expect("test content-length header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("abc")), true)); + session + .write_proxy_tasks() + .await + .expect("initial response should complete"); + + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Body(..))); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + + let mut trailers = HeaderMap::new(); + trailers.insert("x-late", http::HeaderValue::from_static("ignored")); + session.send_proxy_task(HttpTask::Trailer(Some(Box::new(trailers)))); + 
session.send_proxy_task(HttpTask::Body(Some(Bytes::from("late")), true)); + session.send_proxy_task(HttpTask::Failed(Error::explain(InternalError, "late"))); + + let end = session + .write_proxy_tasks() + .await + .expect("late tasks after finished stream should be dropped"); + assert!(end); + assert_rx_empty(&mut handle.rx); + assert!(!session.has_pending_proxy_tasks()); + } + + #[tokio::test] + async fn test_proxy_task_chunked_header_uses_until_close() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("transfer-encoding", "chunked") + .expect("test transfer-encoding header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("chunk")), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("chunked proxy task response should complete"); + assert!(end); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"chunk"), + t => panic!("expected chunked body task, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_eq!(session.body_bytes_sent(), 5); + } + + #[tokio::test] + async fn test_proxy_task_premature_content_length_errors_before_done() { + init_log(); + let (mut session, mut handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + let mut header = test_header(StatusCode::OK); + header + .insert_header("content-length", "5") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("hi")), true)); + + let err = session + .write_proxy_tasks() + .await + .expect_err("premature content-length should error before Done"); + assert_eq!(err.etype(), 
&PREMATURE_BODY_END); + assert_eq!(session.body_writer.body_mode, BodyMode::Complete(2)); + assert!(!session.has_pending_proxy_tasks()); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"hi"), + t => panic!("expected body before premature end, got {t:?}"), + } + assert_rx_empty(&mut handle.rx); + } + + #[tokio::test] + #[should_panic( + expected = "Unexpected UpgradedBody task received on un-upgraded downstream session" + )] + async fn test_upgraded_body_on_non_upgraded_session_panics() { + init_log(); + let (mut session, _handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + assert!(!session.was_upgraded()); + + let header = test_header(StatusCode::OK); + session.send_proxy_task(HttpTask::Header(Box::new(header), false)); + session.send_proxy_task(HttpTask::UpgradedBody(Some(Bytes::from("ws")), true)); + + let _ = session.write_proxy_tasks().await; + } + + #[tokio::test] + #[should_panic(expected = "Unexpected Body task received on upgraded downstream session")] + async fn test_body_on_upgraded_session_panics() { + init_log(); + let (mut session, _handle) = build_upgrade_req("websocket", "Upgrade").await; + session.set_proxy_tasks_enabled(true); + + let mut h101 = test_header(StatusCode::SWITCHING_PROTOCOLS); + h101.set_version(http::Version::HTTP_11); + h101.insert_header("upgrade", "websocket") + .expect("test operation should succeed"); + h101.insert_header("connection", "Upgrade") + .expect("test operation should succeed"); + session.send_proxy_task(HttpTask::Header(Box::new(h101), false)); + session + .write_proxy_tasks() + .await + .expect("test async operation should succeed"); + + session.send_proxy_task(HttpTask::Body(Some(Bytes::from("plain")), true)); + let _ = session.write_proxy_tasks().await; + } + + #[tokio::test] + async fn test_proxy_task_upgraded_body_happy_path() { + init_log(); + let (mut session, mut handle) = 
build_upgrade_req("websocket", "Upgrade").await; + session.set_proxy_tasks_enabled(true); + + let mut h101 = test_header(StatusCode::SWITCHING_PROTOCOLS); + h101.set_version(http::Version::HTTP_11); + h101.insert_header("upgrade", "websocket") + .expect("test upgrade header is valid"); + h101.insert_header("connection", "Upgrade") + .expect("test connection header is valid"); + session.send_proxy_task(HttpTask::Header(Box::new(h101), false)); + session.send_proxy_task(HttpTask::UpgradedBody(Some(Bytes::from("ws")), true)); + + let end = session + .write_proxy_tasks() + .await + .expect("upgraded proxy task response should complete"); + assert!(end); + assert!(session.was_upgraded()); + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Header(..))); + match recv_task(&mut handle.rx) { + HttpTask::Body(Some(b), false) => assert_eq!(&b[..], b"ws"), + t => panic!("expected upgraded body task, got {t:?}"), + } + assert!(matches!(recv_task(&mut handle.rx), HttpTask::Done)); + assert_eq!(session.body_bytes_sent(), 2); + } + + #[tokio::test] + #[should_panic( + expected = "Unexpected UpgradedBody task received on un-upgraded downstream session" + )] + async fn test_upgraded_body_on_non_upgraded_session_panics_while_full() { + init_log(); + let (mut session, _handle) = build_req().await; + session.set_proxy_tasks_enabled(true); + + for _ in 0..4 { + session.send_proxy_task(HttpTask::Header( + Box::new(test_header(StatusCode::CONTINUE)), + false, + )); + } + session.send_proxy_task(HttpTask::UpgradedBody(Some(Bytes::from("ws")), true)); + let _ = session.write_proxy_tasks().await; + } +} diff --git a/pingora-core/src/protocols/http/v1/body.rs b/pingora-core/src/protocols/http/v1/body.rs index 72899257..61872af6 100644 --- a/pingora-core/src/protocols/http/v1/body.rs +++ b/pingora-core/src/protocols/http/v1/body.rs @@ -20,9 +20,14 @@ use pingora_error::{ OrErr, Result, }; use std::fmt::Debug; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; use 
tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::protocols::l4::stream::AsyncWriteVec; +use crate::protocols::l4::stream::{ + async_write_vec::{poll_write_all_buf, poll_write_vec_all_buf}, + AsyncWriteVec, +}; use crate::utils::BufRef; // TODO: make this dynamically adjusted @@ -905,14 +910,188 @@ pub enum BodyMode { type BM = BodyMode; +// ============================================================================ +// Cancel-safe body writing types +// ============================================================================ + +impl BodyMode { + /// Extract `(total, written)` from `ContentLength`, panicking on mismatch. + fn expect_content_length(&self) -> (usize, usize) { + match self { + BodyMode::ContentLength(total, written) => (*total, *written), + _ => panic!("wrong body mode: expected ContentLength, got {:?}", self), + } + } + + /// Extract `written` from `ChunkedEncoding`, panicking on mismatch. + fn expect_chunked(&self) -> usize { + match self { + BodyMode::ChunkedEncoding(written) => *written, + _ => panic!("wrong body mode: expected ChunkedEncoding, got {:?}", self), + } + } + + /// Extract `written` from `UntilClose`, panicking on mismatch. 
+ fn expect_until_close(&self) -> usize { + match self { + BodyMode::UntilClose(written) => *written, + _ => panic!("wrong body mode: expected UntilClose, got {:?}", self), + } + } +} + +/// Type alias for the chunked encoding buffer chain +type ChunkedBuf = bytes::buf::Chain, &'static [u8]>; + +enum WriteBuf { + /// Simple bytes buffer + Simple(Bytes), + /// Chained buffer for chunked encoding or other complex writes + Chained(C), +} + +// Implement Buf for WriteBuf to delegate to the inner buffer +impl Buf for WriteBuf { + fn remaining(&self) -> usize { + match self { + WriteBuf::Simple(b) => b.remaining(), + WriteBuf::Chained(c) => c.remaining(), + } + } + + fn chunk(&self) -> &[u8] { + match self { + WriteBuf::Simple(b) => b.chunk(), + WriteBuf::Chained(c) => c.chunk(), + } + } + + fn advance(&mut self, cnt: usize) { + match self { + WriteBuf::Simple(b) => b.advance(cnt), + WriteBuf::Chained(c) => c.advance(cnt), + } + } +} + +enum WriteState { + /// No write in progress + Idle, + /// Writing data (original size, bytes remaining to write) + Writing(usize, WriteBuf), + /// Flushing after write (original size to return) + Flushing(usize), + /// Write complete (bytes written in this task) + Done(usize), + /// Write timed out - cannot be reused + TimedOut, +} + +// Custom Debug implementation since we can't derive it with futures +impl std::fmt::Debug for WriteState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + WriteState::Idle => write!(f, "Idle"), + WriteState::Writing(size, _buf) => { + write!(f, "Writing(size: {})", size) + } + WriteState::Flushing(size) => write!(f, "Flushing(size: {})", size), + WriteState::Done(size) => write!(f, "Done(size: {})", size), + WriteState::TimedOut => write!(f, "TimedOut"), + } + } +} + +enum FinishWriteState { + /// No finish task queued + NotStarted, + /// Finish queued but not started yet + Idle, + /// Writing last chunk marker (for chunked encoding) + WritingLastChunk(WriteBuf), + 
/// Flushing after writing last chunk + Flushing, + /// Finish complete + Done, +} + +// Custom Debug implementation since WriteBuf doesn't implement Debug +impl std::fmt::Debug for FinishWriteState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FinishWriteState::NotStarted => write!(f, "NotStarted"), + FinishWriteState::Idle => write!(f, "Idle"), + FinishWriteState::WritingLastChunk(_) => write!(f, "WritingLastChunk"), + FinishWriteState::Flushing => write!(f, "Flushing"), + FinishWriteState::Done => write!(f, "Done"), + } + } +} + +/// Internal state for the cancel-safe body write state machine. +/// +/// Tracks the pending body bytes, write progress +/// (idle → writing → flushing → done), and an optional timeout. +struct SendBodyState { + /// Application bytes queued to be written + pending_bytes: Option, + /// Current write state for cancel-safe operations + write_state: WriteState, + /// Timeout duration for this write task + timeout_duration: Option, + /// Timeout future (only created if write returns Pending) + timeout_fut: Option + Send + Sync>>>, +} + +impl std::fmt::Debug for SendBodyState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SendBodyState") + .field("pending_bytes", &self.pending_bytes) + .field("write_state", &self.write_state) + .field("timeout_duration", &self.timeout_duration) + .field( + "timeout_fut", + &self.timeout_fut.as_ref().map(|_| "Some(Future)"), + ) + .finish() + } +} + +impl SendBodyState { + fn new() -> Self { + SendBodyState { + pending_bytes: None, + write_state: WriteState::Idle, + timeout_duration: None, + timeout_fut: None, + } + } +} + +/// Tracks how response body bytes are framed and written to the wire. +/// +/// Supports both a legacy async API (`write_body` / `finish`) and a cancel-safe +/// task API that can be driven inside a `tokio::select!` loop without losing +/// write progress. 
pub struct BodyWriter { pub body_mode: BodyMode, + // Boxed to reduce inline size. Only used by the cancel-safe proxy task API. + send_body_state: Box, + send_finish_state: FinishWriteState, +} + +impl Default for BodyWriter { + fn default() -> Self { + Self::new() + } } impl BodyWriter { pub fn new() -> Self { BodyWriter { body_mode: BM::ToSelect, + send_body_state: Box::new(SendBodyState::new()), + send_finish_state: FinishWriteState::NotStarted, } } @@ -1109,6 +1288,543 @@ impl BodyWriter { _ => panic!("wrong body mode: {:?}", self.body_mode), } } + + // ======================================================================== + // Cancel-safe body task API + // ======================================================================== + + #[cfg(test)] + pub fn has_pending_body_task(&self) -> bool { + self.send_body_state.pending_bytes.is_some() + || !matches!( + self.send_body_state.write_state, + WriteState::Idle | WriteState::Done(_) | WriteState::TimedOut + ) + } + + /// Queue application bytes as a body write task with an optional timeout. + /// This is a non-async function that just saves the bytes. + /// Call `write_current_body_task()` to actually perform the write. + /// + /// The timeout, if provided, will be enforced internally across all + /// write attempts, even if the write is cancelled and resumed via `tokio::select!`. + pub fn send_body_task(&mut self, bytes: Bytes, timeout: Option) { + assert!( + matches!( + self.send_body_state.write_state, + WriteState::Idle | WriteState::Done(_) + ), + "send_body_task called while previous task is still in progress: {:?}", + self.send_body_state.write_state + ); + self.send_body_state.pending_bytes = Some(bytes); + self.send_body_state.write_state = WriteState::Idle; + self.send_body_state.timeout_duration = timeout; + self.send_body_state.timeout_fut = None; + } + + /// Writes the current queued body task to the stream. 
+ /// + /// ## Cancel-safety + /// + /// This function can be safely used in a `tokio::select!` loop. + /// Returns `Ok(Some(bytes_written))` when complete, `Ok(None)` if no bytes to write. + pub async fn write_current_body_task(&mut self, stream: &mut S) -> Result> + where + S: AsyncWrite + Unpin + Send, + { + // Use poll_fn to wrap our poll-based implementation + std::future::poll_fn(|cx| self.poll_write_current_body_task(cx, Pin::new(stream))).await + } + + /// Poll-based implementation for writing body tasks. + /// This is the core implementation that maintains state across cancellations. + fn poll_write_current_body_task( + &mut self, + cx: &mut Context<'_>, + stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Check if already timed out - don't allow reuse + if matches!(self.send_body_state.write_state, WriteState::TimedOut) { + return Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )); + } + + // Lazy timeout optimization: Poll write first, create timeout only if needed. + // + // This follows the pattern from `pingora_timeout::Timeout` to avoid allocating + // and registering timeout futures when writes complete immediately (the common case). 
+ // + // Fast path: Write completes → return immediately, no timeout future created + // Slow path: Write blocks → lazily create timeout future and poll both + + // First, try the write operation + // Dispatch to the appropriate body mode handler + let result = match self.body_mode { + BM::Complete(_) => Poll::Ready(Ok(None)), + BM::ContentLength(_, _) => self.poll_write_content_length_body_task(cx, stream), + BM::ChunkedEncoding(_) => self.poll_write_chunked_body_task(cx, stream), + BM::UntilClose(_) => self.poll_write_until_close_body_task(cx, stream), + BM::ToSelect => Poll::Ready(Ok(None)), + }; + + // If write completed immediately, return without ever creating/polling timeout + if result.is_ready() { + return result; + } + + // Write returned Pending - lazily create and check timeout if duration is set + if let Some(duration) = self.send_body_state.timeout_duration { + let timeout = self.send_body_state.timeout_fut.get_or_insert_with(|| { + Box::pin(pingora_timeout::sleep(duration)) + as std::pin::Pin + Send + Sync>> + }); + + if timeout.as_mut().poll(cx).is_ready() { + // Timeout fired! Mark state as timed out and clear the timeout future + self.send_body_state.write_state = WriteState::TimedOut; + self.send_body_state.timeout_fut = None; + return Poll::Ready(Error::e_explain( + WriteTimedout, + "writing body task timed out", + )); + } + } + + // Both write and timeout are pending + Poll::Pending + } + + // ======================================================================== + // Cancel-safe finish task API + // ======================================================================== + + #[cfg(test)] + pub fn has_pending_finish_task(&self) -> bool { + !matches!( + self.send_finish_state, + FinishWriteState::NotStarted | FinishWriteState::Done + ) + } + + /// Queue a finish operation as a task. + /// This is a non-async function that just marks the finish as pending. + /// Call `write_current_finish_task()` to actually perform the finish. 
+ /// + /// This API is stateful and cancel-safe - use it when you need to finish + /// the body in a `tokio::select!` loop or other cancellable context. + pub fn send_finish_task(&mut self) { + self.send_finish_state = FinishWriteState::Idle; + } + + /// Async function that performs the current queued finish task on the stream. + /// This function is cancel-safe and can be called in a `tokio::select!` loop. + /// Returns `Ok(Some(bytes_written))` when complete, `Ok(None)` if already complete. + /// + /// This API is stateful - it tracks progress across cancellations and can be + /// safely resumed after being dropped mid-execution. + pub async fn write_current_finish_task(&mut self, stream: &mut S) -> Result> + where + S: AsyncWrite + Unpin + Send, + { + // Use poll_fn to wrap our poll-based implementation + std::future::poll_fn(|cx| self.poll_write_current_finish_task(cx, Pin::new(stream))).await + } + + /// Poll-based implementation for finish tasks. + /// This is the core implementation that maintains state across cancellations. + fn poll_write_current_finish_task( + &mut self, + cx: &mut Context<'_>, + stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // If no finish queued, return None + if matches!( + self.send_finish_state, + FinishWriteState::NotStarted | FinishWriteState::Done + ) { + return Poll::Ready(Ok(None)); + } + + // Route to body-mode-specific implementation + match self.body_mode { + BM::Complete(_) => Poll::Ready(Ok(None)), + BM::ContentLength(_, _) => self.poll_finish_content_length_task(cx, stream), + BM::ChunkedEncoding(_) => self.poll_finish_chunked_task(cx, stream), + BM::UntilClose(_) => self.poll_finish_until_close_task(cx, stream), + BM::ToSelect => Poll::Ready(Ok(None)), + } + } + + /// Finish content-length body - just validates and updates state. + /// No I/O needed since body write tasks already flushed after the last write. 
+ fn poll_finish_content_length_task( + &mut self, + _cx: &mut Context<'_>, + _stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + let written = match self.body_mode { + BM::ContentLength(total, w) => { + if w < total { + self.send_finish_state = FinishWriteState::Done; + return Poll::Ready(Error::e_explain( + PREMATURE_BODY_END, + format!("Content-length: {total} bytes written: {w}"), + )); + } + w + } + _ => panic!("wrong body mode: {:?}", self.body_mode), + }; + + // All bytes written - just update state to Complete + self.body_mode = BM::Complete(written); + self.send_finish_state = FinishWriteState::Done; + Poll::Ready(Ok(Some(written))) + } + + /// Poll-based helper to finish chunked encoding body + fn poll_finish_chunked_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + let written = match self.body_mode { + BM::ChunkedEncoding(w) => w, + _ => panic!("wrong body mode: {:?}", self.body_mode), + }; + + loop { + match &mut self.send_finish_state { + FinishWriteState::Idle => { + // Start writing last chunk marker "0\r\n\r\n" + let buf = WriteBuf::Simple(Bytes::from_static(&LAST_CHUNK[..])); + self.send_finish_state = FinishWriteState::WritingLastChunk(buf); + } + FinishWriteState::WritingLastChunk(buf) => { + // Poll write_vec_all - write until all bytes are written + ready!(poll_write_vec_all_buf(cx, stream.as_mut(), buf)) + .map_err(|e| Error::because(WriteError, "while writing last chunk", e))?; + + // All bytes written, move to flushing state + self.send_finish_state = FinishWriteState::Flushing; + } + FinishWriteState::Flushing => { + // Poll flush + ready!(stream.as_mut().poll_flush(cx)) + .map_err(|e| Error::because(WriteError, "flushing after last chunk", e))?; + + // Flush complete! 
Update body_mode and mark done + self.body_mode = BM::Complete(written); + self.send_finish_state = FinishWriteState::Done; + return Poll::Ready(Ok(Some(written))); + } + FinishWriteState::Done => { + unreachable!( + "Done state should have been handled in poll_write_current_finish_task" + ) + } + FinishWriteState::NotStarted => { + unreachable!("NotStarted state should have been handled in poll_write_current_finish_task") + } + } + } + } + + /// Finish until-close body - just updates state. + /// No I/O needed since body write tasks already flushed after each write. + fn poll_finish_until_close_task( + &mut self, + _cx: &mut Context<'_>, + _stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + let written = match self.body_mode { + BM::UntilClose(w) => w, + _ => panic!("wrong body mode: {:?}", self.body_mode), + }; + + // Just update state to Complete + self.body_mode = BM::Complete(written); + self.send_finish_state = FinishWriteState::Done; + Poll::Ready(Ok(Some(written))) + } + + // ======================================================================== + // Internal helpers + // ======================================================================== + + /// Internal helper to poll a body task that writes in content-length mode + /// and flushes at end. 
+ fn poll_write_content_length_body_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Move to Writing state if we're Idle + if matches!(self.send_body_state.write_state, WriteState::Idle) { + if let Some(mut bytes) = self.send_body_state.pending_bytes.take() { + let (total, written) = self.body_mode.expect_content_length(); + + // Check if we've already written everything + if written >= total { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + + let original_size = bytes.len(); + let remaining = total - written; + + // Truncate bytes if they exceed content-length + if original_size > remaining { + warn!( + "Trying to write {} bytes over content-length: {}, truncating to {}", + original_size, total, remaining + ); + bytes.truncate(remaining); + } + + let bytes_to_write = bytes.len(); + self.send_body_state.write_state = + WriteState::Writing(bytes_to_write, WriteBuf::Simple(bytes)); + } else { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + } + + // Handle Writing state - do the write, transition to Flushing or Done + if let WriteState::Writing(size, ref mut buf) = &mut self.send_body_state.write_state { + let bytes_written = *size; + + // Attempt write + match ready!(poll_write_all_buf(cx, stream.as_mut(), buf)) { + Ok(()) => { + // Write completed - update body_mode to track bytes written + let (total, written) = self.body_mode.expect_content_length(); + self.body_mode = BM::ContentLength(total, written + bytes_written); + + if written + bytes_written >= total { + // All content-length bytes written, flush needed + self.send_body_state.write_state = WriteState::Flushing(bytes_written); + } else { + // More bytes to come, no flush needed + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + } + Err(e) => { + return Poll::Ready(Error::e_because(WriteError, "while 
writing body", e)) + } + } + } + + // Handle Flushing state - do the flush, transition to Done + if let WriteState::Flushing(size) = self.send_body_state.write_state { + let bytes_written = size; + + // Attempt flush + match ready!(stream.poll_flush(cx)) { + Ok(()) => { + // Flush completed - transition to Done + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + Err(e) => return Poll::Ready(Error::e_because(WriteError, "flushing body", e)), + } + } + + // Return based on final state + match self.send_body_state.write_state { + WriteState::Done(size) => { + self.send_body_state.timeout_fut = None; + Poll::Ready(Ok(Some(size))) + } + WriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )), + WriteState::Writing(..) | WriteState::Flushing(..) => { + unreachable!("Writing/Flushing states should have been handled above or returned Pending via ready!") + } + WriteState::Idle => { + unreachable!("Idle state should have been handled in setup") + } + } + } + + /// Poll-based implementation for chunked encoding mode + fn poll_write_chunked_body_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Move to Writing state if we're Idle + if matches!(self.send_body_state.write_state, WriteState::Idle) { + if let Some(bytes) = self.send_body_state.pending_bytes.take() { + let application_bytes_size = bytes.len(); + + // Format the chunk: size\r\ndata\r\n + let chunk_size_header = format!("{:X}\r\n", application_bytes_size); + let output_buf = Bytes::from(chunk_size_header) + .chain(bytes) + .chain(&b"\r\n"[..]); + + // Store the chained buffer directly to avoid copying + self.send_body_state.write_state = + WriteState::Writing(application_bytes_size, WriteBuf::Chained(output_buf)); + } else { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + } + + // Handle Writing state - do 
the write using vectored I/O, transition to Flushing + if let WriteState::Writing(size, ref mut buf) = &mut self.send_body_state.write_state { + let bytes_written = *size; + + // Attempt vectored write for chained buffer (chunk size + data + CRLF) + match ready!(poll_write_vec_all_buf(cx, stream.as_mut(), buf)) { + Ok(()) => { + // Write completed - update body_mode with application bytes (not wire bytes) + let written = self.body_mode.expect_chunked(); + self.body_mode = BM::ChunkedEncoding(written + bytes_written); + + // Chunked encoding always flushes + self.send_body_state.write_state = WriteState::Flushing(bytes_written); + } + Err(e) => { + return Poll::Ready(Error::e_because(WriteError, "while writing body", e)) + } + } + } + + // Handle Flushing state - do the flush, transition to Done + if let WriteState::Flushing(size) = self.send_body_state.write_state { + let bytes_written = size; + + // Attempt flush + match ready!(stream.poll_flush(cx)) { + Ok(()) => { + // Flush completed - transition to Done + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + Err(e) => return Poll::Ready(Error::e_because(WriteError, "flushing body", e)), + } + } + + // Return based on final state + match self.send_body_state.write_state { + WriteState::Done(size) => { + self.send_body_state.timeout_fut = None; + Poll::Ready(Ok(Some(size))) + } + WriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )), + WriteState::Writing(..) | WriteState::Flushing(..) 
=> { + unreachable!("Writing/Flushing states should have been handled above or returned Pending via ready!") + } + WriteState::Idle => { + unreachable!("Idle state should have been handled in setup") + } + } + } + + /// Poll-based implementation for UntilClose (close-delimited) body mode + fn poll_write_until_close_body_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Move to Writing state if we're Idle + if matches!(self.send_body_state.write_state, WriteState::Idle) { + if let Some(bytes) = self.send_body_state.pending_bytes.take() { + let original_size = bytes.len(); + self.send_body_state.write_state = + WriteState::Writing(original_size, WriteBuf::Simple(bytes)); + } else { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + } + + // Handle Writing state - do the write, transition to Flushing + if let WriteState::Writing(size, ref mut buf) = &mut self.send_body_state.write_state { + let bytes_written = *size; + + // Attempt write + match ready!(poll_write_all_buf(cx, stream.as_mut(), buf)) { + Ok(()) => { + // Write completed - update body_mode to track bytes written + let written = self.body_mode.expect_until_close(); + self.body_mode = BM::UntilClose(written + bytes_written); + + // Close-delimited mode always flushes + self.send_body_state.write_state = WriteState::Flushing(bytes_written); + } + Err(e) => { + return Poll::Ready(Error::e_because(WriteError, "while writing body", e)) + } + } + } + + // Handle Flushing state - do the flush, transition to Done + if let WriteState::Flushing(size) = self.send_body_state.write_state { + let bytes_written = size; + + // Attempt flush + match ready!(stream.poll_flush(cx)) { + Ok(()) => { + // Flush completed - transition to Done + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + Err(e) => return Poll::Ready(Error::e_because(WriteError, "flushing body", e)), + } + } 
+ + // Return based on final state + match self.send_body_state.write_state { + WriteState::Done(size) => { + self.send_body_state.timeout_fut = None; + Poll::Ready(Ok(Some(size))) + } + WriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )), + WriteState::Writing(..) | WriteState::Flushing(..) => { + unreachable!("Writing/Flushing states should have been handled above or returned Pending via ready!") + } + WriteState::Idle => { + unreachable!("Idle state should have been handled in setup") + } + } + } } #[cfg(test)] @@ -1717,33 +2433,41 @@ mod tests { let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(0, 0)); assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); - let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(3, 1)); // input1 concat input2 - assert_eq!(&input2[1..2], body_reader.get_body(&res)); - assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 11, 0)); - let res = body_reader.read_body(&mut mock_io).await.unwrap(); - assert_eq!(res, None); - assert_eq!(body_reader.body_state, ParseState::Complete(1)); - assert_eq!(body_reader.get_body_overread(), None); + let _res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); } #[tokio::test] - async fn read_with_body_partial_head_terminal_crlf() { + async fn read_with_body_partial_head_chunk_incomplete() { init_log(); let input1 = b"1\r"; - let input2 = b"\na\r\n0\r\n\r"; - let input3 = b"\n"; - let mut mock_io = Builder::new() - .read(&input1[..]) - .read(&input2[..]) - .read(&input3[..]) - .build(); + let mut mock_io = Builder::new().read(&input1[..]).build(); let mut body_reader = BodyReader::new(false); body_reader.init_chunked(b""); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(0, 0)); assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); - let res = 
body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + let res = body_reader.read_body(&mut mock_io).await; + assert!(res.is_err()); + assert_eq!(body_reader.body_state, ParseState::Done(0)); + } + + #[tokio::test] + async fn read_with_body_partial_head_terminal_crlf() { + init_log(); + let input1 = b"1\r"; + let input2 = b"\na\r\n0\r\n\r"; + let input3 = b"\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(3, 1)); // input1 concat input2 assert_eq!(&input2[1..2], body_reader.get_body(&res)); assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 10, 0)); @@ -1925,21 +2649,6 @@ mod tests { assert_eq!(body_reader.get_body_overread(), Some(&b"abc"[..])); } - #[tokio::test] - async fn read_with_body_partial_head_chunk_incomplete() { - init_log(); - let input1 = b"1\r"; - let mut mock_io = Builder::new().read(&input1[..]).build(); - let mut body_reader = BodyReader::new(false); - body_reader.init_chunked(b""); - let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(0, 0)); - assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); - let res = body_reader.read_body(&mut mock_io).await; - assert!(res.is_err()); - assert_eq!(body_reader.body_state, ParseState::Done(0)); - } - #[tokio::test] async fn read_with_body_trailers() { init_log(); @@ -2319,7 +3028,7 @@ mod tests { } #[tokio::test] - async fn write_body_http10() { + async fn write_body_until_close() { init_log(); let data = b"a"; let mut mock_io = Builder::new().write(&data[..]).write(&data[..]).build(); @@ -2345,3 
+3054,799 @@ mod tests { assert_eq!(body_writer.body_mode, BodyMode::Complete(2)); } } + +#[cfg(test)] +mod test_body_task_api { + use super::*; + use crate::protocols::http::v1::test_util::FlushTrackingMock; + use tokio_test::io::Builder; + + // Cancel-safety tests use tokio::select! to race a short sleep against a mock + // I/O wait, simulating cancellation. We use #[tokio::test(start_paused = true)] + // on these tests so that tokio auto-advances time deterministically rather than + // relying on wall-clock timing. + + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[tokio::test] + async fn test_has_pending_body_task() { + init_log(); + let data = b"test data"; + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Initially should have no pending task + assert!(!body_writer.has_pending_body_task()); + + // After queuing bytes, should have pending task + body_writer.send_body_task(Bytes::from_static(data), None); + assert!(body_writer.has_pending_body_task()); + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_content_length_write() { + init_log(); + let data = b"Hello, World!"; + + // Create a mock stream that will block to allow cancellation + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue the bytes to write + body_writer.send_body_task(Bytes::from_static(data), None); + + // Use tokio::select! loop - keep looping until write completes + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + // Break if no pending writes + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + // Timeout fires first, cancelling the write + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + // Write completed + assert!(result.is_ok(), "Write should succeed"); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!( + cancel_count > 0, + "At least one cancellation should have occurred" + ); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + assert_eq!( + body_writer.body_mode, + BodyMode::ContentLength(data.len(), data.len()) + ); + + // Now test finish() in a select loop as well + let mut mock_io_finish = Builder::new().build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => { + // Allow cancellation attempts + } + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + + assert_eq!(body_writer.body_mode, BodyMode::Complete(data.len())); + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_chunked_write() { + init_log(); + let data = b"abcdefghij"; + let expected_output = b"A\r\nabcdefghij\r\n"; + + // Mock stream that blocks to allow cancellation + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(expected_output) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_chunked(); + + // Queue bytes + body_writer.send_body_task(Bytes::from_static(data), None); + + // Use select loop - keep looping until write completes + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count > 0, "Should have cancelled at least once"); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + assert_eq!(body_writer.body_mode, BodyMode::ChunkedEncoding(data.len())); + + // Test finish() with select loop - must write terminating chunk + let mut mock_io_finish = Builder::new() + .wait(std::time::Duration::from_millis(50)) + .write(&LAST_CHUNK[..]) // Expect 0\r\n\r\n + .build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + + assert_eq!(body_writer.body_mode, BodyMode::Complete(data.len())); + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_until_close_write() { + init_log(); + let data = b"test data"; + + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_close_delimited(); + + body_writer.send_body_task(Bytes::from_static(data), None); + + // Use select loop - keep looping until write completes + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count > 0); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + + // Test finish() with select loop + let mut mock_io_finish = Builder::new().build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_multiple_cancellations() { + init_log(); + let data = b"Long test data that requires multiple writes"; + + // Create a mock that blocks multiple times + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(50)) + .write(&data[..15]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[15..30]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[30..]) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + body_writer.send_body_task(Bytes::from_static(data), None); + + // Loop until write completes, allowing cancellations + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count >= 2, "Should have multiple cancellations"); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + + // Test finish with select loop + let mut mock_io_finish = Builder::new().build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_partial_writes() { + init_log(); + let data = b"12345678901234567890"; // 20 bytes + + // Simulate partial writes with blocking + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(50)) + .write(&data[..7]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[7..14]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[14..]) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + body_writer.send_body_task(Bytes::from_static(data), None); + + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count > 0); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + + // Test finish in select loop + let mut mock_io_finish = Builder::new() + .wait(std::time::Duration::from_millis(30)) + .build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok(), "Finish should succeed after cancel-safe writes"); + break; + } + } + } + } + + #[tokio::test] + async fn test_task_write_timeout() { + init_log(); + let data = b"test data"; + + // Create a mock that blocks forever + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue the task with a timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(50)), + ); + + // The write should timeout + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_err(), "Write should timeout"); + + // Check that it's a timeout error + if let Err(e) = result { + assert_eq!(e.etype(), &WriteTimedout); + } + } + + // Even if the user's select! cancels the write, the internal timeout + // should continue counting across cancellations. 
+ #[tokio::test] + async fn test_task_timeout_persists_across_cancellations() { + init_log(); + let data = b"test data"; + + // Create a mock that blocks for a while + // Since timeout is 100ms and this waits 200ms, the write should never happen + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(200)) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue the task with a 100ms timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(100)), + ); + + let mut attempts = 0; + let mut timedout = false; + + // Try to write in a loop, but cancel early each time + // The timeout should still fire even though we're cancelling + loop { + if !body_writer.has_pending_body_task() { + break; + } + + attempts += 1; + + tokio::select! { + // Cancel after just 10ms each time + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + // Cancelled by our select, continue looping + continue; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + match result { + Ok(_) => { + // Write succeeded before timeout + break; + } + Err(e) if e.etype() == &WriteTimedout => { + // Timeout fired! 
+ timedout = true; + break; + } + Err(e) => { + panic!("Unexpected error: {:?}", e); + } + } + } + } + } + + assert!(timedout, "Timeout should have fired despite cancellations"); + assert!( + attempts >= 5, + "Should have had multiple attempts before timeout" + ); + } + + #[tokio::test] + async fn test_task_write_succeeds_within_timeout() { + init_log(); + let data = b"Hello, World!"; + + // Create a mock that completes quickly + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(20)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue with a generous timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(500)), + ); + + // Write should succeed + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok(), "Write should succeed: {:?}", result); + assert_eq!(result.unwrap(), Some(data.len())); + } + + #[tokio::test] + async fn test_task_write_no_timeout() { + init_log(); + let data = b"test data"; + + // Create a mock that takes a bit of time but eventually succeeds + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue without timeout + body_writer.send_body_task(Bytes::from_static(data), None); + + // Write should eventually succeed + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok(), "Write should succeed without timeout"); + assert_eq!(result.unwrap(), Some(data.len())); + } + + #[tokio::test] + async fn test_task_chunked_write_timeout() { + init_log(); + let data = b"chunked data"; + + // Create a mock that blocks + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut body_writer = BodyWriter::new(); + 
body_writer.init_chunked(); + + // Queue with short timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(50)), + ); + + // Should timeout + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!(e.etype(), &WriteTimedout); + } + } + + #[tokio::test] + async fn test_task_timeout_reset_on_new_task() { + init_log(); + let data1 = b"first"; + let data2 = b"second"; + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data1.len() + data2.len()); + + // Queue first task with short timeout + body_writer.send_body_task( + Bytes::from_static(data1), + Some(std::time::Duration::from_millis(50)), + ); + + // Wait a bit but don't let it timeout yet + tokio::time::sleep(std::time::Duration::from_millis(30)).await; + + // Queue a new task with a longer timeout + // This should reset/replace the timeout + body_writer.send_body_task( + Bytes::from_static(data2), + Some(std::time::Duration::from_millis(500)), + ); + + // Create a mock that takes some time + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data2) + .build(); + + // The second write should succeed with its own timeout + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!( + result.is_ok(), + "Second task should succeed with new timeout" + ); + } + + #[tokio::test] + async fn test_task_timeout_with_partial_writes() { + init_log(); + let data1 = b"first"; + let data2 = b"second"; + let data3 = b"third"; + + // Mock that writes data1 quickly, data2 with delay, data3 blocks forever + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(10)) + .write(data1) + .wait(std::time::Duration::from_millis(40)) + .write(data2) + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data1.len() + 
data2.len() + data3.len()); + + let mut total_written = 0; + + // First write - should succeed within timeout + body_writer.send_body_task( + Bytes::from_static(data1), + Some(std::time::Duration::from_millis(100)), + ); + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok()); + total_written += result.unwrap().unwrap(); + + // Second write - should succeed within timeout + body_writer.send_body_task( + Bytes::from_static(data2), + Some(std::time::Duration::from_millis(100)), + ); + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok()); + total_written += result.unwrap().unwrap(); + + // Third write - should timeout + body_writer.send_body_task( + Bytes::from_static(data3), + Some(std::time::Duration::from_millis(50)), + ); + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().etype(), &WriteTimedout); + + // We should have written data1 and data2 but not data3 + assert_eq!(total_written, data1.len() + data2.len()); + assert!( + total_written < data1.len() + data2.len() + data3.len(), + "Should not have written all data" + ); + } + + // Cancel-safe finish task for chunked encoding: send_finish_task() queues + // the terminating chunk, write_current_finish_task() writes it and can be + // cancelled and resumed in a select! loop. + // Verifies that the finish flushes the stream exactly once. 
+ #[tokio::test(start_paused = true)] + async fn cancel_safe_finish_task_chunked() { + init_log(); + + let data = Bytes::from("hello"); + let expected_chunk = b"5\r\nhello\r\n"; + + let mock_io = Builder::new().write(expected_chunk).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + + let mut body_writer = BodyWriter::new(); + body_writer.init_chunked(); + + // Write body data via task API + body_writer.send_body_task(data, None); + body_writer + .write_current_body_task(&mut flush_mock) + .await + .unwrap(); + + // Chunked body writes always flush after each chunk + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 1, + "Chunked body data write should flush once" + ); + + // Queue the finish task + body_writer.send_finish_task(); + assert!(body_writer.has_pending_finish_task()); + + // Write the finish in a select! loop with cancellations + let mock_io_finish = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(b"0\r\n\r\n") + .build(); + let (mut flush_mock_finish, flush_count_finish) = FlushTrackingMock::new(mock_io_finish); + + let mut cancel_count = 0; + + loop { + if !body_writer.has_pending_finish_task() { + break; + } + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_finish_task(&mut flush_mock_finish) => { + assert!(result.is_ok()); + break; + } + } + } + + assert!(cancel_count > 0, "Should have cancelled at least once"); + assert!(matches!(body_writer.body_mode, BodyMode::Complete(_))); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count_finish), + 1, + "Chunked finish should flush exactly once" + ); + } + + // Finish task for content-length is a no-op (no terminating chunk needed), + // but it should still transition body_mode to Complete. + // Verifies that no flush occurs (content-length finish has no I/O). 
+ #[tokio::test] + async fn finish_task_content_length() { + init_log(); + + let data = b"hello"; + let mock_io = Builder::new().write(data).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + body_writer.send_body_task(Bytes::from_static(data), None); + body_writer + .write_current_body_task(&mut flush_mock) + .await + .unwrap(); + + // Content-length body write flushes when all bytes are written + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 1, + "Content-length body write should flush once (all bytes written)" + ); + + body_writer.send_finish_task(); + let mock_io_finish = Builder::new().build(); + let (mut flush_mock_finish, flush_count_finish) = FlushTrackingMock::new(mock_io_finish); + let result = body_writer + .write_current_finish_task(&mut flush_mock_finish) + .await; + assert!(result.is_ok()); + assert!(matches!(body_writer.body_mode, BodyMode::Complete(_))); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count_finish), + 0, + "Content-length finish should not flush (no I/O needed)" + ); + } + + // Verifies that body_mode byte tracking is correct when writing + // content-length body in multiple chunks. Each intermediate chunk + // does not trigger a flush; the body_mode must still accumulate + // bytes correctly so that finish_task succeeds. 
+ #[tokio::test] + async fn content_length_body_mode_tracks_across_chunks() { + init_log(); + + let chunk1 = b"Hello"; + let chunk2 = b", World!"; + let total_len = chunk1.len() + chunk2.len(); // 13 + + // Mock expects both writes; the final write triggers a flush internally + let mut mock_io = Builder::new().write(chunk1).write(chunk2).build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(total_len); + + // Write first chunk (intermediate, no flush expected) + body_writer.send_body_task(Bytes::from_static(chunk1), None); + let result = body_writer + .write_current_body_task(&mut mock_io) + .await + .unwrap(); + assert_eq!(result, Some(chunk1.len())); + assert!( + !body_writer.finished(), + "Should not be finished after first chunk" + ); + + // Verify body_mode tracks the bytes from the first chunk + assert!( + matches!(body_writer.body_mode, BodyMode::ContentLength(total, written) + if total == total_len && written == chunk1.len()), + "body_mode should reflect bytes written so far, got: {:?}", + body_writer.body_mode + ); + + // Write second chunk (final, completes content-length) + body_writer.send_body_task(Bytes::from_static(chunk2), None); + let result = body_writer + .write_current_body_task(&mut mock_io) + .await + .unwrap(); + assert_eq!(result, Some(chunk2.len())); + assert!( + body_writer.finished(), + "Should be finished after all bytes written" + ); + + // Finish should succeed since all content-length bytes were written + body_writer.send_finish_task(); + let mut mock_io_finish = Builder::new().build(); + let result = body_writer + .write_current_finish_task(&mut mock_io_finish) + .await; + assert!( + result.is_ok(), + "finish_task should succeed when all content-length bytes written" + ); + assert!(matches!(body_writer.body_mode, BodyMode::Complete(_))); + } +} diff --git a/pingora-core/src/protocols/http/v1/client.rs b/pingora-core/src/protocols/http/v1/client.rs index 5f9e4610..a60aad1f 100644 --- 
a/pingora-core/src/protocols/http/v1/client.rs +++ b/pingora-core/src/protocols/http/v1/client.rs @@ -625,13 +625,6 @@ impl HttpSession { // follow https://datatracker.ietf.org/doc/html/rfc9112#section-6.3 let preread_body = self.preread_body.as_ref().unwrap().get(&self.buf[..]); - if let Some(req) = self.request_written.as_ref() { - if req.method == http::method::Method::HEAD { - self.body_reader.init_content_length(0, preread_body); - return; - } - } - let upgraded = if let Some(code) = self.get_status() { match code.as_u16() { 101 => self.is_upgrade_req(), @@ -650,6 +643,13 @@ impl HttpSession { false }; + if let Some(req) = self.request_written.as_ref() { + if req.method == http::method::Method::HEAD { + self.body_reader.init_content_length(0, preread_body); + return; + } + } + if upgraded { self.body_reader.init_close_delimited(preread_body); self.close_delimited_resp = true; @@ -2224,6 +2224,283 @@ hello"; http_stream.respect_keepalive(); assert!(!http_stream.will_keepalive()); } + + #[tokio::test] + async fn read_informational_head_request() { + init_log(); + // HEAD request that receives 100 Continue followed by 200 OK + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input1 = b"HTTP/1.1 100 Continue\r\n\r\n"; + let input2 = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n"; + + let mock_io = Builder::new() + .write(&wire[..]) + .read(&input1[..]) + .read(&input2[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + // Write HEAD request + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + // Read 100 Continue + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 100); + assert!(!eob, "100 Continue for HEAD should not signal end of body"); + } + _ => { + panic!("task should be informational header") + } + } + + // Read final 200 OK + let task = 
http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 200); + assert!(eob, "HEAD 200 response should signal end of body"); + } + _ => { + panic!("task should be final header") + } + } + + // Body reader should be Complete(0) for HEAD + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + } + + #[tokio::test] + async fn read_informational_multiple_head_request() { + init_log(); + // HEAD request that receives 100 Continue, 103 Early Hints, then 200 OK + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 103 Early Hints\r\n\r\nHTTP/1.1 200 OK\r\nContent-Length: 50\r\n\r\n"; + + let mock_io = Builder::new().write(&wire[..]).read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + // Read 100 Continue + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 100); + assert!(!eob, "100 Continue for HEAD should not signal end of body"); + } + _ => { + panic!("task should be 100 header") + } + } + + // Read 103 Early Hints + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 103); + assert!( + !eob, + "103 Early Hints for HEAD should not signal end of body" + ); + } + _ => { + panic!("task should be 103 header") + } + } + + // Read 200 OK — end of body + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 200); + assert!(eob, "HEAD 200 response should signal end of body"); + } + _ => { + panic!("task should be final header") + } + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + } + + #[tokio::test] + 
async fn read_basic_head() { + init_log(); + // Basic HEAD + 200 + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n"; + + let mock_io = Builder::new().write(&wire[..]).read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 200); + assert!(eob, "HEAD 200 should be end of body"); + } + _ => { + panic!("task should be header") + } + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + + // Keepalive should work for a properly-framed HEAD response + http_stream.respect_keepalive(); + assert!(http_stream.will_keepalive()); + } + + #[tokio::test] + async fn read_head_informational_keepalive() { + init_log(); + // HEAD + 100 Continue + 200 OK, then verify keepalive is preserved. 
+ let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input1 = b"HTTP/1.1 100 Continue\r\n\r\n"; + let input2 = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n"; + + let mock_io = Builder::new() + .write(&wire[..]) + .read(&input1[..]) + .read(&input2[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + // 100 Continue + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 100); + assert!(!eob); + } + _ => panic!("task should be informational header"), + } + + // 200 OK + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 200); + assert!(eob); + } + _ => panic!("task should be final header"), + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + + // Keepalive must still work after the 100 + 200 sequence + http_stream.respect_keepalive(); + assert!(http_stream.will_keepalive()); + } + + #[tokio::test] + async fn read_head_204() { + init_log(); + // HEAD + 204 No Content + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input = b"HTTP/1.1 204 No Content\r\n\r\n"; + + let mock_io = Builder::new().write(&wire[..]).read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 204); + assert!(eob, "HEAD 204 should be end of body"); + } + _ => panic!("task should be header"), + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + } + + #[tokio::test] + async fn read_head_304() { + 
init_log(); + // HEAD + 304 Not Modified + let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input = b"HTTP/1.1 304 Not Modified\r\nContent-Length: 100\r\n\r\n"; + + let mock_io = Builder::new().write(&wire[..]).read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 304); + assert!(eob, "HEAD 304 should be end of body"); + } + _ => panic!("task should be header"), + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + } + + #[tokio::test] + async fn read_head_101_non_upgrade() { + init_log(); + // HEAD + 101 where the request is not an upgrade request. + // Contrived, but verifies the new code path: 101 check fires first, + // is_upgrade_req() returns false, then HEAD check fires. 
+ let wire = b"HEAD / HTTP/1.1\r\n\r\n"; + let input = b"HTTP/1.1 101 Switching Protocols\r\n\r\n"; + + let mock_io = Builder::new().write(&wire[..]).read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("HEAD", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 101); + // HEAD without Upgrade headers → not an upgrade, body is "done" + assert!(eob, "HEAD 101 (non-upgrade) should be end of body"); + } + _ => panic!("task should be header"), + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + } } #[cfg(test)] diff --git a/pingora-core/src/protocols/http/v1/header.rs b/pingora-core/src/protocols/http/v1/header.rs new file mode 100644 index 00000000..b6abdb71 --- /dev/null +++ b/pingora-core/src/protocols/http/v1/header.rs @@ -0,0 +1,459 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! 
Cancel-safe header writing for HTTP/1.x + +use bytes::Bytes; +use pingora_error::{Error, ErrorType::*, Result}; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; +use tokio::io::AsyncWrite; + +use crate::protocols::l4::stream::async_write_vec::poll_write_all_buf; + +enum HeaderWriteState { + /// No write in progress + Idle, + /// Writing header bytes (original size, buffer) + Writing(usize, Bytes), + /// Flushing after write (original size to return) + Flushing(usize), + /// Write complete + Done, + /// Write timed out - cannot be reused + TimedOut, +} + +// Custom Debug implementation +impl std::fmt::Debug for HeaderWriteState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + HeaderWriteState::Idle => write!(f, "Idle"), + HeaderWriteState::Writing(size, _) => write!(f, "Writing(size: {})", size), + HeaderWriteState::Flushing(size) => write!(f, "Flushing(size: {})", size), + HeaderWriteState::Done => write!(f, "Done"), + HeaderWriteState::TimedOut => write!(f, "TimedOut"), + } + } +} + +/// Internal state for the cancel-safe header write state machine. +/// +/// Tracks the pending header bytes, write progress (idle → writing → flushing → done), +/// and an optional timeout that is lazily created on the first `Pending` poll. 
+struct SendHeaderState { + /// Serialized header bytes ready to be written + pending_header: Option, + /// Whether to flush after writing + should_flush: bool, + /// Current write state + write_state: HeaderWriteState, + /// Timeout duration for this write task + timeout_duration: Option, + /// Timeout future (only created if write returns Pending) + timeout_fut: Option + Send + Sync>>>, +} + +// Custom Debug implementation since timeout_fut doesn't implement Debug +impl std::fmt::Debug for SendHeaderState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SendHeaderState") + .field("pending_header", &self.pending_header) + .field("should_flush", &self.should_flush) + .field("write_state", &self.write_state) + .field("timeout_duration", &self.timeout_duration) + .field( + "timeout_fut", + &self.timeout_fut.as_ref().map(|_| "Some(Future)"), + ) + .finish() + } +} + +impl SendHeaderState { + fn new() -> Self { + SendHeaderState { + pending_header: None, + should_flush: false, + write_state: HeaderWriteState::Idle, + timeout_duration: None, + timeout_fut: None, + } + } +} + +/// Cancel-safe header writer for HTTP/1.x response headers. +/// +/// This writer allows response headers to be written to a downstream connection +/// inside a `tokio::select!` loop without losing progress. If the write is +/// cancelled (e.g. because another branch of the select fires first), the +/// partially-written state is preserved and will be resumed on the next call to +/// [`write_current_header_task`](Self::write_current_header_task). +/// +/// ## Usage +/// +/// 1. Call [`send_header_task`](Self::send_header_task) with pre-serialized +/// header bytes, a flush flag, and an optional timeout. +/// 2. Await [`write_current_header_task`](Self::write_current_header_task) +/// (possibly inside `tokio::select!`). The method returns `Ok(bytes_written)` +/// on success. 
+/// +/// A timeout, if set, is enforced *across* cancellations — the clock keeps +/// ticking even when the future is dropped and re-polled. +pub struct HeaderWriter { + // Boxed to reduce inline size. Only used by the cancel-safe proxy task API. + send_header_state: Box, +} + +impl Default for HeaderWriter { + fn default() -> Self { + Self::new() + } +} + +impl HeaderWriter { + pub fn new() -> Self { + HeaderWriter { + send_header_state: Box::new(SendHeaderState::new()), + } + } + + #[cfg(test)] + pub fn has_pending_header_task(&self) -> bool { + self.send_header_state.pending_header.is_some() + || !matches!( + self.send_header_state.write_state, + HeaderWriteState::Idle | HeaderWriteState::Done | HeaderWriteState::TimedOut + ) + } + + /// Queue serialized header bytes as a write task with an optional timeout. + /// This is a non-async function that just saves the bytes. + /// Call [`write_current_header_task`](Self::write_current_header_task) to actually perform the write. + pub fn send_header_task( + &mut self, + header_bytes: Bytes, + should_flush: bool, + timeout: Option, + ) { + assert!( + matches!( + self.send_header_state.write_state, + HeaderWriteState::Idle | HeaderWriteState::Done + ), + "send_header_task called while previous task is still in progress: {:?}", + self.send_header_state.write_state + ); + self.send_header_state.pending_header = Some(header_bytes); + self.send_header_state.should_flush = should_flush; + self.send_header_state.write_state = HeaderWriteState::Idle; + self.send_header_state.timeout_duration = timeout; + self.send_header_state.timeout_fut = None; + } + + /// Async function that writes the current queued header task to the stream. + /// This function is cancel-safe and can be called in a `tokio::select!` loop. + /// Returns `Ok(bytes_written)` when complete, `Ok(0)` if no bytes to write. 
+ pub async fn write_current_header_task(&mut self, stream: &mut S) -> Result + where + S: AsyncWrite + Unpin + Send, + { + std::future::poll_fn(|cx| self.poll_write_current_header_task(cx, Pin::new(stream))).await + } + + /// Poll-based implementation for writing the current header task. + fn poll_write_current_header_task( + &mut self, + cx: &mut Context<'_>, + stream: Pin<&mut S>, + ) -> Poll> + where + S: AsyncWrite + Unpin + Send, + { + // Check if already timed out - don't allow reuse + if matches!( + self.send_header_state.write_state, + HeaderWriteState::TimedOut + ) { + return Poll::Ready(Error::e_explain( + WriteTimedout, + "header write task previously timed out", + )); + } + + // First, try the write operation + match self.poll_do_write_header_and_flush(cx, stream) { + Poll::Ready(Ok(size)) => { + // Write completed! Clear timeout and return + if matches!(self.send_header_state.write_state, HeaderWriteState::Done) { + self.send_header_state.timeout_fut = None; + } + return Poll::Ready(Ok(size)); + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + // Write is pending - now check timeout + } + } + + // Lazy timeout optimization: Polls write first, creates timeout only if needed. + // This follows the pattern from `pingora_timeout::Timeout` to avoid allocating + // timeout futures when writes complete immediately (the common case). + if let Some(duration) = self.send_header_state.timeout_duration { + let timeout = self.send_header_state.timeout_fut.get_or_insert_with(|| { + Box::pin(pingora_timeout::sleep(duration)) + as std::pin::Pin + Send + Sync>> + }); + + if timeout.as_mut().poll(cx).is_ready() { + // Timeout fired! 
+ self.send_header_state.write_state = HeaderWriteState::TimedOut; + self.send_header_state.timeout_fut = None; + return Poll::Ready(Error::e_explain( + WriteTimedout, + "writing header task timed out", + )); + } + } + + // Both write and timeout are pending + Poll::Pending + } + + /// Poll-based helper to write header bytes and optionally flush. + /// Handles state transitions explicitly. + fn poll_do_write_header_and_flush( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll> + where + S: AsyncWrite + Unpin + Send, + { + // Handle Idle state - take pending header and transition to Writing + if matches!(self.send_header_state.write_state, HeaderWriteState::Idle) { + if let Some(header_bytes) = self.send_header_state.pending_header.take() { + let size = header_bytes.len(); + self.send_header_state.write_state = HeaderWriteState::Writing(size, header_bytes); + } else { + // No pending header + self.send_header_state.write_state = HeaderWriteState::Done; + return Poll::Ready(Ok(0)); + } + } + + // Write if in Writing state + if let HeaderWriteState::Writing(original_size, ref mut buf) = + self.send_header_state.write_state + { + let size = original_size; + ready!(poll_write_all_buf(cx, stream.as_mut(), buf)) + .map_err(|e| Error::because(WriteError, "writing response header", e))?; + + // Write complete - transition to next state + if self.send_header_state.should_flush { + self.send_header_state.write_state = HeaderWriteState::Flushing(size); + } else { + self.send_header_state.write_state = HeaderWriteState::Done; + return Poll::Ready(Ok(size)); + } + } + + // Handle the state after writing (or if we started in a non-Writing state) + match self.send_header_state.write_state { + HeaderWriteState::Flushing(size) => { + ready!(stream.as_mut().poll_flush(cx)) + .map_err(|e| Error::because(WriteError, "flushing response header", e))?; + // Flush complete - transition to Done + self.send_header_state.write_state = HeaderWriteState::Done; + 
Poll::Ready(Ok(size)) + } + HeaderWriteState::Done => Poll::Ready(Ok(0)), + HeaderWriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "header write task previously timed out", + )), + HeaderWriteState::Idle => { + unreachable!("Idle state should have been handled above") + } + HeaderWriteState::Writing(..) => { + unreachable!("Writing state should have been handled above") + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocols::http::v1::test_util::FlushTrackingMock; + use tokio_test::io::Builder; + + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[tokio::test] + async fn test_simple_header_write_no_flush() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + + let mock_io = Builder::new().write(header_data).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task(Bytes::from_static(header_data), false, None); + + let result = header_writer + .write_current_header_task(&mut flush_mock) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), header_data.len()); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 0, + "should_flush=false should not flush" + ); + } + + #[tokio::test] + async fn test_header_write_with_flush() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\n\r\n"; + + let mock_io = Builder::new().write(header_data).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task(Bytes::from_static(header_data), true, None); + + let result = header_writer + .write_current_header_task(&mut flush_mock) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), header_data.len()); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 1, + "should_flush=true should flush exactly once" + ); + } 
+ + // Uses start_paused for deterministic timer-based cancellation in select! + #[tokio::test(start_paused = true)] + async fn test_cancel_safe_header_write() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\nServer: pingora\r\n\r\n"; + + // Mock that blocks to allow cancellation + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(header_data) + .build(); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task(Bytes::from_static(header_data), false, None); + + let mut cancel_count = 0; + + loop { + if !header_writer.has_pending_header_task() { + break; + } + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = header_writer.write_current_header_task(&mut mock_io) => { + assert!(result.is_ok()); + assert_eq!(result.unwrap(), header_data.len()); + break; + } + } + } + + assert!(cancel_count > 0, "Should have cancelled at least once"); + } + + #[tokio::test] + async fn test_header_write_timeout() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\n\r\n"; + + // Mock that blocks forever + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task( + Bytes::from_static(header_data), + false, + Some(std::time::Duration::from_millis(50)), + ); + + let result = header_writer.write_current_header_task(&mut mock_io).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().etype(), &WriteTimedout); + } + + #[tokio::test] + async fn test_header_write_timeout_persists() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\n\r\n"; + + // Mock that blocks for a while + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(200)) + .build(); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task( + Bytes::from_static(header_data), + false, + 
Some(std::time::Duration::from_millis(100)), + ); + + let mut attempts = 0; + let mut timedout = false; + + loop { + if !header_writer.has_pending_header_task() { + break; + } + + attempts += 1; + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + continue; + } + result = header_writer.write_current_header_task(&mut mock_io) => { + match result { + Ok(_) => break, + Err(e) if e.etype() == &WriteTimedout => { + timedout = true; + break; + } + Err(e) => panic!("Unexpected error: {:?}", e), + } + } + } + } + + assert!(timedout, "Timeout should have fired"); + assert!(attempts >= 5, "Should have had multiple attempts"); + } +} diff --git a/pingora-core/src/protocols/http/v1/mod.rs b/pingora-core/src/protocols/http/v1/mod.rs index 19602491..6f085a70 100644 --- a/pingora-core/src/protocols/http/v1/mod.rs +++ b/pingora-core/src/protocols/http/v1/mod.rs @@ -17,4 +17,110 @@ pub(crate) mod body; pub mod client; pub mod common; +pub(crate) mod header; pub mod server; + +/// Test utilities shared across HTTP/1.x unit tests +#[cfg(test)] +pub(crate) mod test_util { + use std::pin::Pin; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use std::task::{Context, Poll}; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio_test::io::Mock; + + /// A wrapper around [`Mock`] that counts flush calls. + /// + /// `tokio_test::io::Mock`'s `poll_flush` always returns `Ready(Ok(()))`, + /// so we can't detect flush calls via mock alone. This wrapper counts them. 
+ #[derive(Debug)] + pub(crate) struct FlushTrackingMock { + inner: Mock, + flush_count: Arc, + } + + impl FlushTrackingMock { + pub(crate) fn new(mock: Mock) -> (Self, Arc) { + let flush_count = Arc::new(AtomicUsize::new(0)); + ( + FlushTrackingMock { + inner: mock, + flush_count: flush_count.clone(), + }, + flush_count, + ) + } + + pub(crate) fn flush_count(counter: &Arc) -> usize { + counter.load(Ordering::Relaxed) + } + } + + impl AsyncRead for FlushTrackingMock { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_read(cx, buf) + } + } + + impl AsyncWrite for FlushTrackingMock { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let result = Pin::new(&mut this.inner).poll_flush(cx); + if let Poll::Ready(Ok(())) = &result { + this.flush_count.fetch_add(1, Ordering::Relaxed); + } + result + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_shutdown(cx) + } + } + + // Implement IO-required traits so FlushTrackingMock can be used as Box + // in HttpSession tests (server.rs). 
+ use crate::protocols::{ + raw_connect::ProxyDigest, GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, + SocketDigest, Ssl, TimingDigest, UniqueID, UniqueIDType, + }; + + #[async_trait::async_trait] + impl Shutdown for FlushTrackingMock { + async fn shutdown(&mut self) -> () {} + } + impl UniqueID for FlushTrackingMock { + fn id(&self) -> UniqueIDType { + 0 + } + } + impl Ssl for FlushTrackingMock {} + impl GetTimingDigest for FlushTrackingMock { + fn get_timing_digest(&self) -> Vec> { + vec![] + } + } + impl GetProxyDigest for FlushTrackingMock { + fn get_proxy_digest(&self) -> Option> { + None + } + } + impl GetSocketDigest for FlushTrackingMock { + fn get_socket_digest(&self) -> Option> { + None + } + } + impl Peek for FlushTrackingMock {} +} diff --git a/pingora-core/src/protocols/http/v1/server.rs b/pingora-core/src/protocols/http/v1/server.rs index 03ebf81f..9144c6e5 100644 --- a/pingora-core/src/protocols/http/v1/server.rs +++ b/pingora-core/src/protocols/http/v1/server.rs @@ -28,15 +28,48 @@ use pingora_http::{IntoCaseHeaderName, RequestHeader, ResponseHeader}; use pingora_timeout::timeout; use regex::bytes::Regex; use std::any::Any; +use std::collections::VecDeque; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::body::{BodyReader, BodyWriter}; use super::common::*; +use super::header::HeaderWriter; use crate::protocols::http::{body_buffer::FixedBuffer, date, HttpTask}; use crate::protocols::{Digest, SocketAddr, Stream}; use crate::utils::{BufRef, KVRef}; +/// Tracks which writer is currently processing a task. +/// +/// This enables resuming writes after cancellation. Each variant stores the +/// minimal data needed for cleanup after write completes. +#[derive(Debug)] +enum ProxyTaskWriter { + /// Currently writing a header task. + /// Stores: (header for `response_written`, end_stream flag) + WritingHeader(Box, bool), + /// Currently writing a body task (`Body` or `UpgradedBody`). 
+ /// Stores: (end_stream flag) + WritingBody(bool), + /// Currently finishing the body (writing last chunk + flush). + FinishingBody, +} + +/// State for the cancel-safe proxy task write API. +#[derive(Default)] +struct ProxyTaskState { + /// Lazily initialized — `HeaderWriter::new()` heap-allocates. + header_writer: Option, + tasks: VecDeque, + current_writer: Option, +} + +impl ProxyTaskState { + fn header_writer(&mut self) -> &mut HeaderWriter { + self.header_writer.get_or_insert_with(HeaderWriter::new) + } +} + /// The HTTP 1.x server session pub struct HttpSession { underlying_stream: Stream, @@ -52,6 +85,8 @@ pub struct HttpSession { body_reader: BodyReader, /// A state machine to track how to write the response body body_writer: BodyWriter, + /// Cancel-safe proxy task state. + proxy_task_state: ProxyTaskState, /// An internal buffer to buf multiple body writes to reduce the underlying syscalls body_write_buf: BytesMut, /// Track how many application (not on the wire) body bytes already sent @@ -91,6 +126,16 @@ pub struct HttpSession { /// Set by [`HttpPersistentSettings::apply_to_session`](crate::apps::HttpPersistentSettings::apply_to_session), /// consumed by the proxy layer via [`take_connection_user_context`](Self::take_connection_user_context). connection_user_context: Option>, + /// Whether the client has closed the TCP connection (sent FIN / read returned 0). + half_closed: bool, + /// When true (default), a client close after the request body is surfaced as a + /// `ConnectionClosed` error so the proxy aborts immediately. When false, the + /// close is tolerated and `read_body_or_idle` stays pending so the proxy can + /// finish delivering the upstream response (RFC 9112 Section 9.6). + abort_on_close: bool, + /// Whether the cancel-safe proxy task API is enabled for this session. + /// Defaults to false. Can be enabled via [`set_proxy_tasks_enabled`](Self::set_proxy_tasks_enabled). 
+ proxy_tasks_enabled: bool, } impl HttpSession { @@ -113,6 +158,7 @@ impl HttpSession { preread_body: None, body_reader: BodyReader::new(false), body_writer: BodyWriter::new(), + proxy_task_state: ProxyTaskState::default(), body_write_buf: BytesMut::new(), keepalive_timeout: KeepaliveStatus::Off, update_resp_headers: true, @@ -132,6 +178,9 @@ impl HttpSession { close_on_response_before_downstream_finish: true, keepalive_reuses_remaining: None, connection_user_context: None, + half_closed: false, + abort_on_close: true, + proxy_tasks_enabled: false, } } @@ -501,101 +550,12 @@ impl HttpSession { /// Write the response header to the client. /// This function can be called more than once to send 1xx informational headers excluding 101. pub async fn write_response_header(&mut self, mut header: Box) -> Result<()> { - if header.status.is_informational() && self.ignore_info_resp(header.status.into()) { - debug!("ignoring informational headers"); + // Prepare header (handle upgrades, set headers, initialize body writer, serialize to bytes) + let Some((write_buf, flush)) = self.prepare_response_header(&mut header)? 
else { + // Header already sent or should be ignored return Ok(()); - } - - if let Some(resp) = self.response_written.as_ref() { - if !resp.status.is_informational() || self.upgraded { - warn!("Respond header is already sent, cannot send again"); - return Ok(()); - } - } - - // if body unfinished, or request header was not finished reading - if self.close_on_response_before_downstream_finish - && (self.request_header.is_none() || !self.is_body_done()) - { - debug!("set connection close before downstream finish"); - self.set_keepalive(None); - } - - // no need to add these headers to 1xx responses - if !header.status.is_informational() && self.update_resp_headers { - /* update headers */ - header.insert_header(header::DATE, date::get_cached_date())?; - - // TODO: make these lazy static - let connection_value = if self.will_keepalive() { - "keep-alive" - } else { - "close" - }; - header.insert_header(header::CONNECTION, connection_value)?; - } - - if header.status == 101 { - // make sure the connection is closed at the end when 101/upgrade is used - self.set_keepalive(None); - } - - // Allow informational header (excluding 101) to pass through without affecting the state - // of the request - if header.status == 101 || !header.status.is_informational() { - // reset request body to done for incomplete upgrade handshakes - if let Some(upgrade_ok) = self.is_upgrade(&header) { - if upgrade_ok { - debug!("ok upgrade handshake"); - // For ws we use HTTP1_0 do_read_body_until_closed - // - // On ws close the initiator sends a close frame and - // then waits for a response from the peer, once it receives - // a response it closes the conn. After receiving a - // control frame indicating the connection should be closed, - // a peer discards any further data received. - // https://www.rfc-editor.org/rfc/rfc6455#section-1.4 - self.upgraded = true; - // Now that the upgrade was successful, we need to change - // how we interpret the rest of the body as pass-through. 
- if self.body_reader.need_init() { - self.init_body_reader(); - } else { - // already initialized - // immediately start reading the rest of the body as upgraded - // (in practice most upgraded requests shouldn't have any body) - // - // TODO: https://datatracker.ietf.org/doc/html/rfc9110#name-upgrade - // the most spec-compliant behavior is to switch interpretation - // after sending the former body, - // we immediately switch interpretation to match nginx - self.body_reader.convert_to_close_delimited(); - } - } else { - // this was a request that requested Upgrade, - // but upstream did not comply - debug!("bad upgrade handshake!"); - // continue to read body as-is, this is now just a regular request - } - } - self.init_body_writer(&header); - } - - // Defense-in-depth: if response body is close-delimited, mark session - // as un-reusable - if self.body_writer.is_close_delimited() { - self.set_keepalive(None); - } - - // Don't have to flush response with content length because it is less - // likely to be real time communication. So do flush when - // 1.1xx response: client needs to see it before the rest of response - // 2.No content length: the response could be generated in real time - let flush = header.status.is_informational() - || header.headers.get(header::CONTENT_LENGTH).is_none(); + }; - let mut write_buf = BytesMut::with_capacity(INIT_HEADER_BUF_SIZE); - http_resp_header_to_buf(&header, &mut write_buf).unwrap(); match self.underlying_stream.write_all(&write_buf).await { Ok(()) => { // flush the stream if 1xx header or there is no response body @@ -606,7 +566,6 @@ impl HttpSession { .or_err(WriteError, "flushing response header")?; } self.response_written = Some(header); - self.body_bytes_sent += write_buf.len(); Ok(()) } Err(e) => Error::e_because(WriteError, "writing response header", e), @@ -750,6 +709,117 @@ impl HttpSession { } } + /// Prepare response header for writing: handle upgrades, set headers, initialize body writer. 
+ /// This contains all the synchronous logic that should happen before writing the header. + /// Returns Ok(Some((bytes, should_flush))) if the header should be written, Ok(None) if should skip. + fn prepare_response_header( + &mut self, + header: &mut ResponseHeader, + ) -> Result> { + // Check if we should ignore informational responses + if header.status.is_informational() && self.ignore_info_resp(header.status.into()) { + debug!("ignoring informational headers"); + return Ok(None); + } + + // Check if we already sent a response header + if let Some(ref resp) = self.response_written { + if !resp.status.is_informational() || self.upgraded { + warn!("Respond header is already sent, cannot send again"); + return Ok(None); + } + } + + // if body unfinished, or request header was not finished reading + if self.close_on_response_before_downstream_finish + && (self.request_header.is_none() || !self.is_body_done()) + { + debug!("set connection close before downstream finish"); + self.set_keepalive(None); + } + + // no need to add these headers to 1xx responses + if !header.status.is_informational() && self.update_resp_headers { + /* update headers */ + header.insert_header(header::DATE, date::get_cached_date())?; + + // TODO: make these lazy static + let connection_value = if self.will_keepalive() { + "keep-alive" + } else { + "close" + }; + header.insert_header(header::CONNECTION, connection_value)?; + } + + if header.status == 101 { + // make sure the connection is closed at the end when 101/upgrade is used + self.set_keepalive(None); + } + + // Allow informational header (excluding 101) to pass through without affecting the state + // of the request + if header.status == 101 || !header.status.is_informational() { + // reset request body to done for incomplete upgrade handshakes + if let Some(upgrade_ok) = self.is_upgrade(header) { + if upgrade_ok { + debug!("ok upgrade handshake"); + // For ws we use HTTP1_0 do_read_body_until_closed + // + // On ws close the 
initiator sends a close frame and + // then waits for a response from the peer, once it receives + // a response it closes the conn. After receiving a + // control frame indicating the connection should be closed, + // a peer discards any further data received. + // https://www.rfc-editor.org/rfc/rfc6455#section-1.4 + self.upgraded = true; + // Now that the upgrade was successful, we need to change + // how we interpret the rest of the body as pass-through. + if self.body_reader.need_init() { + self.init_body_reader(); + } else { + // already initialized + // immediately start reading the rest of the body as upgraded + // (in practice most upgraded requests shouldn't have any body) + // + // TODO: https://datatracker.ietf.org/doc/html/rfc9110#name-upgrade + // the most spec-compliant behavior is to switch interpretation + // after sending the former body, + // we immediately switch interpretation to match nginx + self.body_reader.convert_to_close_delimited(); + } + } else { + // this was a request that requested Upgrade, + // but upstream did not comply + debug!("bad upgrade handshake!"); + // continue to read body as-is, this is now just a regular request + } + } + self.init_body_writer(header); + } + + // Defense-in-depth: if response body is close-delimited, mark session + // as un-reusable + if self.body_writer.is_close_delimited() { + self.set_keepalive(None); + } + + // Serialize header to bytes + let mut write_buf = BytesMut::with_capacity(INIT_HEADER_BUF_SIZE); + http_resp_header_to_buf(header, &mut write_buf) + .map_err(|_| Error::explain(WriteError, "serializing response header"))?; + + // Determine if we should flush + // Don't have to flush response with content length because it is less + // likely to be real time communication. So do flush when + // 1. 1xx response: client needs to see it before the rest of response + // 2. 
No content length: the response could be generated in real time + let should_flush = header.status.is_informational() + || header.headers.get(header::CONTENT_LENGTH).is_none(); + + Ok(Some((write_buf.freeze(), should_flush))) + } + fn init_body_writer(&mut self, header: &ResponseHeader) { use http::StatusCode; /* the following responses don't have body 204, 304, and HEAD */ @@ -805,6 +875,16 @@ impl HttpSession { } } + /// Whether the cancel-safe proxy task API is enabled for this session. + pub fn proxy_tasks_enabled(&self) -> bool { + self.proxy_tasks_enabled + } + + /// Enable or disable the cancel-safe proxy task API for this session. + pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { + self.proxy_tasks_enabled = enabled; + } + async fn do_write_body_buf(&mut self) -> Result> { // Don't flush empty chunks, they are considered end of body for chunks if self.body_write_buf.is_empty() { @@ -961,19 +1041,54 @@ impl HttpSession { /// This function will return body bytes (same as [`Self::read_body_bytes()`]), but after /// the client body finishes (`Ok(None)` is returned), calling this function again will block /// forever, same as [`Self::idle()`]. + /// + /// By default (`abort_on_close = true`), if the client closes the connection + /// (sends TCP FIN, i.e. `read == 0`) after the request body is complete, a + /// `ConnectionClosed` error is returned. + /// + /// When `abort_on_close` is **disabled**, the close is tolerated: the future stays + /// pending so the proxy can finish delivering the upstream response via the write + /// path (per RFC 9112 Section 9.6). A true disconnect (RST) will be caught later + /// when the response write fails. + /// + /// Note that this marks the connection as half-closed if FIN is detected. If this function + /// is called after the connection is already marked half-closed and `abort_on_close` is + /// **disabled**, then it will pend forever. 
pub async fn read_body_or_idle(&mut self, no_body_expected: bool) -> Result> { if no_body_expected || self.is_body_done() { + if self.half_closed { + if self.abort_on_close { + return Error::e_explain( + ConnectionClosed, + if self.response_written.is_none() { + "Prematurely before response header is sent" + } else { + "Prematurely before response body is complete" + }, + ); + } + return std::future::pending().await; + } // XXX: account for upgraded body reader change, if the read half split from the write half let read = self.idle().await?; if read == 0 { - Error::e_explain( - ConnectionClosed, - if self.response_written.is_none() { - "Prematurely before response header is sent" - } else { - "Prematurely before response body is complete" - }, - ) + self.half_closed = true; + self.set_keepalive(None); + if self.abort_on_close { + Error::e_explain( + ConnectionClosed, + if self.response_written.is_none() { + "Prematurely before response header is sent" + } else { + "Prematurely before response body is complete" + }, + ) + } else { + debug!("downstream closed (FIN), keeping write side open"); + // If the connection is fully closed, writing the response side + // will fail. + std::future::pending().await + } } else { Error::e_explain(ConnectError, "Sent data after end of body") } @@ -982,6 +1097,11 @@ impl HttpSession { } } + /// Whether the client has half-closed the TCP connection. + pub fn is_half_closed(&self) -> bool { + self.half_closed + } + /// Return the raw bytes of the request header. pub fn get_headers_raw_bytes(&self) -> Bytes { self.raw_header.as_ref().unwrap().get_bytes(&self.buf) @@ -1079,6 +1199,18 @@ impl HttpSession { self.close_on_response_before_downstream_finish = close; } + /// Controls behaviour when the client closes the connection after the request body. + /// + /// When **enabled** (default), a client close is returned as a `ConnectionClosed` + /// error so the proxy aborts immediately. 
+ /// + /// When **disabled**, `read_body_or_idle` stays pending on a client close so the + /// proxy can finish delivering the upstream response (RFC 9112 Section 9.6). A true + /// disconnect (RST) will surface later when the response write fails. + pub fn set_abort_on_close(&mut self, abort: bool) { + self.abort_on_close = abort; + } + /// Return the [Digest] of the connection. pub fn digest(&self) -> &Digest { &self.digest @@ -1259,6 +1391,152 @@ impl HttpSession { Ok(end_stream || self.body_writer.finished()) } + /// Queue a proxy task for cancel-safe writing with the current write_timeout. + /// The task will be written when `write_proxy_tasks()` is called. + /// + /// A write canceled mid-operation can be resumed via `write_proxy_tasks()`. + pub fn send_proxy_task(&mut self, task: HttpTask) { + self.proxy_task_state.tasks.push_back(task); + } + + /// Check if there are pending proxy tasks queued for writing. + pub fn has_pending_proxy_tasks(&self) -> bool { + self.proxy_task_state.current_writer.is_some() || !self.proxy_task_state.tasks.is_empty() + } + + /// Write all queued proxy tasks (response `HttpTask`s from `send_proxy_task`) + /// in a cancel-safe manner. + /// + /// If cancelled mid-write, the next call will resume the in-progress write. + /// + /// Returns `Ok(true)` if this was the end of the response stream. + // Leverages the cancel-safe `HeaderWriter` and `BodyWriter` primitives. + // TODO: we can do the same for the non-cancel-safe APIs. 
+ pub async fn write_proxy_tasks(&mut self) -> Result { + let mut end_stream = false; + + // TODO: buffer body data like response_duplex_vec + loop { + // - Resume any in-progress write + if let Some(ref writer_state) = self.proxy_task_state.current_writer { + match writer_state { + ProxyTaskWriter::WritingHeader(_, _) => { + let _bytes_written = self + .proxy_task_state + .header_writer() + .write_current_header_task(&mut self.underlying_stream) + .await + .map_err(|e| e.into_down())?; + } + ProxyTaskWriter::WritingBody(_) => { + let written = self + .body_writer + .write_current_body_task(&mut self.underlying_stream) + .await + .map_err(|e| e.into_down())?; + if let Some(n) = written { + self.body_bytes_sent += n; + } + } + ProxyTaskWriter::FinishingBody => { + self.body_writer + .write_current_finish_task(&mut self.underlying_stream) + .await + .map_err(|e| e.into_down())?; + } + } + + match self + .proxy_task_state + .current_writer + .take() + .expect("writer state present") + { + ProxyTaskWriter::WritingHeader(header, end) => { + self.response_written = Some(header); + end_stream = end; + } + ProxyTaskWriter::WritingBody(end) => { + end_stream = end; + } + ProxyTaskWriter::FinishingBody => { + end_stream = true; + self.maybe_force_close_body_reader(); + break; // fine to break after finish, no tasks should be queued after + } + } + continue; + } + + // - Send tasks, set state. + // Pop next task + let Some(task) = self.proxy_task_state.tasks.pop_front() else { + if end_stream { + self.body_writer.send_finish_task(); + self.proxy_task_state.current_writer = Some(ProxyTaskWriter::FinishingBody); + continue; + } + break; + }; + + match task { + HttpTask::Header(mut header, end) => { + let Some((write_buf, should_flush)) = + self.prepare_response_header(&mut header)? 
+ else { + end_stream = end; + continue; + }; + // header only responses will want to flush + let flush = should_flush || self.body_writer.finished(); + self.proxy_task_state + .header_writer() + .send_header_task(write_buf, flush, None); + self.proxy_task_state.current_writer = + Some(ProxyTaskWriter::WritingHeader(header, end)); + } + HttpTask::Body(ref data, end) => { + if self.upgraded { + panic!("Unexpected Body task received on upgraded downstream session"); + } + if let Some(d) = data.as_ref() { + if !d.is_empty() { + let body_timeout = self.write_timeout(d.len()); + self.body_writer.send_body_task(d.clone(), body_timeout); + self.proxy_task_state.current_writer = + Some(ProxyTaskWriter::WritingBody(end)); + continue; + } + } + end_stream = end; + } + HttpTask::UpgradedBody(ref data, end) => { + if !self.upgraded { + panic!("Unexpected UpgradedBody task received on un-upgraded downstream session"); + } + if let Some(d) = data.as_ref() { + if !d.is_empty() { + let body_timeout = self.write_timeout(d.len()); + self.body_writer.send_body_task(d.clone(), body_timeout); + self.proxy_task_state.current_writer = + Some(ProxyTaskWriter::WritingBody(end)); + continue; + } + } + end_stream = end; + } + HttpTask::Trailer(_) | HttpTask::Done => { + end_stream = true; + } + HttpTask::Failed(e) => { + return Err(e); + } + } + } + + Ok(end_stream || self.body_writer.finished()) + } + /// Get the reference of the [Stream] that this HTTP session is operating upon. 
pub fn stream(&self) -> &Stream { &self.underlying_stream @@ -2320,6 +2598,30 @@ mod tests_stream { assert_eq!(wire_body.len(), n); } + #[tokio::test] + async fn body_bytes_sent_excludes_response_header() { + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; + let wire_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let wire_body = b"hello"; + let mock_io = Builder::new() + .read(read_wire) + .write(wire_header) + .write(wire_body) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); + new_response.append_header("Content-Length", "5").unwrap(); + http_stream.update_resp_headers = false; + http_stream + .write_response_header(Box::new(new_response)) + .await + .unwrap(); + assert_eq!(http_stream.body_bytes_sent(), 0); + http_stream.write_body(wire_body).await.unwrap(); + assert_eq!(http_stream.body_bytes_sent(), wire_body.len()); + } + #[tokio::test] async fn write_body_http10() { let read_wire = b"GET / HTTP/1.1\r\n\r\n"; @@ -2791,25 +3093,30 @@ mod test_sync { } #[cfg(test)] -mod test_timeouts { +mod test_proxy_tasks { use super::*; + use http::StatusCode; use std::future::IntoFuture; use tokio_test::io::{Builder, Mock}; - /// An upper limit for any read within any test to prevent tests from hanging forever if - /// an internal read call never returns, etc. + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + // An upper limit for any read within any test to prevent tests from hanging forever if + // an internal read call never returns, etc. 
const TEST_MAX_WAIT_FOR_READ: Duration = Duration::from_secs(3); - /// The duration of 600 seconds is chosen to be "effectively forever" for the purpose of testing + // The duration of 600 seconds is chosen to be "effectively forever" for the purpose of testing const TEST_FOREVER_DURATION: Duration = Duration::from_secs(600); - /// The read_timeout to use, when we want to test that a read operation times out + // The read_timeout to use, when we want to test that a read operation times out const TEST_READ_TIMEOUT: Duration = Duration::from_secs(1); #[derive(Debug)] struct ReadBlockedForeverError; - /// Returns a client stream that will "never" send any bytes / return from a read operation + // Returns a client stream that will "never" send any bytes / return from a read operation fn mocked_blocking_headers_forever_stream() -> Box { Box::new(Builder::new().wait(TEST_FOREVER_DURATION).build()) } @@ -2826,8 +3133,8 @@ mod test_timeouts { ) } - /// Helper function to test a read operation with a tokio timeout - /// to prevent tests from hanging forever in case of a bug + // Helper function to test a read operation with a tokio timeout + // to prevent tests from hanging forever in case of a bug async fn test_read_with_tokio_timeout( read_future: F, ) -> Result>, ReadBlockedForeverError> @@ -2871,6 +3178,352 @@ mod test_timeouts { assert!(res.is_ok()); assert_eq!(res.unwrap().unwrap_err().etype(), &ReadTimedout); } + + #[tokio::test] + async fn test_send_proxy_task_and_write() { + init_log(); + + // We need to know exact bytes that will be written + // "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello" + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let expected_body = b"hello"; + + let mock_io = Builder::new() + .write(expected_header) + .write(expected_body) + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; // Disable automatic headers + + // Queue header task + let mut 
header = ResponseHeader::build(StatusCode::OK, Some(5)).unwrap(); + header.insert_header("Content-Length", "5").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + + // Queue body task + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // Write all tasks + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + assert!(end_stream); + } + + #[tokio::test] + async fn test_proxy_task_with_timeout() { + init_log(); + + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let expected_body = b"hello"; + + let mock_io = Builder::new() + .write(expected_header) + .write(expected_body) + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + http_stream.write_timeout = Some(Duration::from_secs(1)); // Set write timeout + + // Queue tasks + let mut header = ResponseHeader::build(StatusCode::OK, Some(5)).unwrap(); + header.insert_header("Content-Length", "5").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // Verify initial state + assert_eq!( + http_stream.body_bytes_sent(), + 0, + "Should start with 0 bytes sent" + ); + + // Write all tasks with timeout + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + assert!(end_stream); + + // Verify body bytes were counted correctly (not double counted) + assert_eq!( + http_stream.body_bytes_sent(), + 5, + "Should count exactly 5 bytes (application level), not double counted" + ); + } + + // Test that write_proxy_tasks is cancel-safe: if the future is dropped mid-execution, + // unwritten tasks should remain in the queue. 
+ #[tokio::test] + async fn test_proxy_task_cancel_safety() { + init_log(); + + let expected_header = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"; + // First chunk: "5\r\nhello\r\n" + let expected_chunk1 = b"5\r\nhello\r\n"; + + // Create a mock IO that will write the header and first chunk, + // but will block indefinitely on the second chunk + let mock_io = Builder::new() + .write(expected_header) + .write(expected_chunk1) + .wait(Duration::from_secs(999)) // This will cause timeout + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + http_stream.write_timeout = Some(Duration::from_millis(100)); + + // Queue 3 tasks: header + 2 body chunks + let mut header = ResponseHeader::build(StatusCode::OK, None).unwrap(); + header + .insert_header("Transfer-Encoding", "chunked") + .unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("world")), true)); + + // Verify we have 3 tasks queued + assert_eq!(http_stream.proxy_task_state.tasks.len(), 3); + + // Try to write all tasks - this should timeout while writing the second body chunk + let result = http_stream.write_proxy_tasks().await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().etype(), &WriteTimedout); + + // With the refactored cancel-safe design: + // - First task (header) was written successfully and removed from queue + // - Second task (first body "hello") was removed and sent to BodyWriter, write succeeded, state cleared + // - Third task (second body "world") was removed and sent to BodyWriter, timed out mid-write + // - The in-progress write state is tracked in current_writer, NOT in the queue + assert_eq!( + http_stream.proxy_task_state.tasks.len(), + 0, + "Queue should be empty - tasks are owned by writers once sent" + ); + + // The task being 
written should be tracked in current_writer + assert!( + matches!( + http_stream.proxy_task_state.current_writer, + Some(ProxyTaskWriter::WritingBody(_)) + ), + "Should be mid-write of body task - writer owns the 'world' task state" + ); + + // Verify body_bytes_sent only counts the successfully written "hello" (5 bytes) + // not the timed-out "world" + assert_eq!( + http_stream.body_bytes_sent(), + 5, + "Should only count the 5 bytes from 'hello', not the incomplete 'world' write" + ); + + // On next call to write_proxy_tasks(), Step 1 will resume the "world" write + } + + use crate::protocols::http::v1::test_util::FlushTrackingMock; + + // Test that write_continue_response can be called before write_proxy_tasks + // and both work correctly together. + #[tokio::test] + async fn test_continue_response_before_proxy_tasks() { + init_log(); + + // Expected bytes written: + // 1. 100 Continue response + // 2. 200 OK response header + // 3. Body data + let expected_continue = b"HTTP/1.1 100 Continue\r\n\r\n"; + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let expected_body = b"hello"; + + let mock_io = Builder::new() + .write(expected_continue) + .write(expected_header) + .write(expected_body) + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; // Disable automatic headers + + // First, write the 100 Continue response + http_stream.write_continue_response().await.unwrap(); + + // Verify that 100 Continue was recorded + assert!( + http_stream.response_written().is_some(), + "100 Continue should be recorded in response_written" + ); + assert_eq!( + http_stream.response_written().unwrap().status, + StatusCode::CONTINUE, + "Should have recorded 100 Continue" + ); + + // Now queue the actual response using proxy tasks + let mut header = ResponseHeader::build(StatusCode::OK, Some(5)).unwrap(); + header.insert_header("Content-Length", "5").unwrap(); + 
http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // Write all proxy tasks + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + assert!(end_stream, "Should indicate end of stream"); + + // Verify final response is 200 OK, not 100 Continue + assert_eq!( + http_stream.response_written().unwrap().status, + StatusCode::OK, + "Final response should be 200 OK, overwriting 100 Continue" + ); + } + + #[tokio::test] + async fn test_head_response_with_content_length_flushes() { + init_log(); + + // HEAD request line + headers + let request = b"HEAD / HTTP/1.1\r\nHost: example.com\r\n\r\n"; + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n"; + + let mock_io = Builder::new().read(request).write(expected_header).build(); + let (flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + let mut http_stream = HttpSession::new(Box::new(flush_mock)); + http_stream.update_resp_headers = false; + + // Read the HEAD request + http_stream.read_request().await.unwrap(); + assert_eq!(http_stream.get_method(), Some(&Method::HEAD)); + + // Queue header with Content-Length (body will be empty for HEAD) + let mut header = ResponseHeader::build(StatusCode::OK, Some(2)).unwrap(); + header.insert_header("Content-Length", "100").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), true)); + + let flush_before = FlushTrackingMock::flush_count(&flush_count); + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + let flush_after = FlushTrackingMock::flush_count(&flush_count); + + assert!(end_stream, "HEAD response should be end of stream"); + assert!( + flush_after > flush_before, + "Should flush after writing HEAD response header with Content-Length \ + (body_writer.finished() is true). 
Got flush_before={flush_before}, \ + flush_after={flush_after}" + ); + } + + #[tokio::test] + async fn test_204_response_with_content_length_flushes() { + init_log(); + + let request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"; + let expected_header = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n"; + + let mock_io = Builder::new().read(request).write(expected_header).build(); + let (flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + let mut http_stream = HttpSession::new(Box::new(flush_mock)); + http_stream.update_resp_headers = false; + + http_stream.read_request().await.unwrap(); + + let mut header = ResponseHeader::build(StatusCode::NO_CONTENT, Some(2)).unwrap(); + header.insert_header("Content-Length", "0").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), true)); + + let flush_before = FlushTrackingMock::flush_count(&flush_count); + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + let flush_after = FlushTrackingMock::flush_count(&flush_count); + + assert!(end_stream, "204 response should be end of stream"); + assert!( + flush_after > flush_before, + "Should flush after writing 204 response header with Content-Length \ + (body_writer.finished() is true). 
Got flush_before={flush_before}, \ + flush_after={flush_after}" + ); + } + + #[tokio::test] + #[should_panic( + expected = "Unexpected UpgradedBody task received on un-upgraded downstream session" + )] + async fn test_upgraded_body_on_non_upgraded_session_panics() { + init_log(); + + let request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"; + let expected_header = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"; + // UpgradedBody on a non-upgraded session should panic before writing, + // but if the bug exists, BodyWriter would encode it as a chunk: + let expected_chunk = b"5\r\nhello\r\n"; + let expected_finish = b"0\r\n\r\n"; + + let mock_io = Builder::new() + .read(request) + .write(expected_header) + // If the panic check is missing, the body gets written as a chunk + .write(expected_chunk) + .write(expected_finish) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + + http_stream.read_request().await.unwrap(); + assert!( + !http_stream.was_upgraded(), + "Session should NOT be upgraded" + ); + + // Queue a normal header + let mut header = ResponseHeader::build(StatusCode::OK, Some(2)).unwrap(); + header + .insert_header("Transfer-Encoding", "chunked") + .unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + + // Queue an UpgradedBody task on a non-upgraded session — should panic + http_stream.send_proxy_task(HttpTask::UpgradedBody(Some(Bytes::from("hello")), true)); + + // This should panic before/during the body write + let _ = http_stream.write_proxy_tasks().await; + } + + #[tokio::test] + #[should_panic(expected = "Unexpected Body task received on upgraded downstream session")] + async fn test_body_on_upgraded_session_panics() { + init_log(); + + // Upgrade request + let request = + b"GET / HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; + // 101 Switching Protocols response + let expected_header = + b"HTTP/1.1 
101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; + // If the panic check is missing, Body data would be written raw (close-delimited) + let expected_body = b"hello"; + + let mock_io = Builder::new() + .read(request) + .write(expected_header) + .write(expected_body) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + + http_stream.read_request().await.unwrap(); + + // Queue 101 header to complete the upgrade + let mut header = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, Some(3)).unwrap(); + header.insert_header("Upgrade", "websocket").unwrap(); + header.insert_header("Connection", "Upgrade").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + + // Queue a regular Body task on what will be an upgraded session — should panic + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // This should panic (after writing the header, session becomes upgraded, + // then the Body task should be rejected) + let _ = http_stream.write_proxy_tasks().await; + } } #[cfg(test)] @@ -2980,3 +3633,116 @@ mod test_overread { assert!(reused.is_none()); } } + +#[cfg(test)] +mod test_abort_on_close { + use super::*; + use pingora_error::ErrorType; + use tokio_test::io::Builder; + + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + /// Helper: create an HttpSession whose request has been read and body is done, + /// with the mock stream returning EOF on the next read (simulating client FIN). 
+ async fn session_with_eof() -> HttpSession { + let request = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let mock_io = Builder::new().read(&request[..]).build(); + let mut s = HttpSession::new(Box::new(mock_io)); + s.read_request().await.unwrap(); + s + } + + #[tokio::test] + async fn default_abort_on_close_returns_error() { + init_log(); + let mut s = session_with_eof().await; + + assert!(s.abort_on_close); + let err = s.read_body_or_idle(true).await.unwrap_err(); + assert_eq!(*err.etype(), ErrorType::ConnectionClosed); + assert!(s.is_half_closed()); + } + + #[tokio::test] + async fn abort_on_close_false_stays_pending() { + init_log(); + let mut s = session_with_eof().await; + s.set_abort_on_close(false); + + let result = tokio::time::timeout( + std::time::Duration::from_millis(50), + s.read_body_or_idle(true), + ) + .await; + + assert!(result.is_err(), "expected timeout (pending), got a result"); + assert!(s.is_half_closed()); + } + + #[tokio::test] + async fn abort_on_close_error_message_before_response() { + init_log(); + let mut s = session_with_eof().await; + + assert!(s.response_written().is_none()); + let err = s.read_body_or_idle(true).await.unwrap_err(); + let msg = format!("{err}"); + assert!( + msg.contains("Prematurely before response header is sent"), + "unexpected error message: {msg}" + ); + } + + #[tokio::test] + async fn abort_on_close_error_message_after_response_header() { + init_log(); + let mut s = session_with_eof().await; + + // Simulate that a response header has already been sent. 
+ let resp = ResponseHeader::build(200, None).unwrap(); + s.response_written = Some(Box::new(resp)); + let err = s.read_body_or_idle(true).await.unwrap_err(); + let msg = format!("{err}"); + assert!( + msg.contains("Prematurely before response body is complete"), + "unexpected error message: {msg}" + ); + } + + #[tokio::test] + async fn no_body_expected_false_reads_body_then_idles() { + init_log(); + let request = b"POST / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 3\r\n\r\n"; + let mock_io = Builder::new().read(&request[..]).read(b"abc").build(); + let mut s = HttpSession::new(Box::new(mock_io)); + s.read_request().await.unwrap(); + + // 1) no_body_expected = false should still read request body while not done. + let body = s.read_body_or_idle(false).await.unwrap().unwrap(); + assert_eq!(body.as_ref(), b"abc"); + assert!(s.is_body_done()); + + // 2) Once body is naturally done, it transitions to idle behavior on the next call. + let err = s.read_body_or_idle(false).await.unwrap_err(); + assert_eq!(*err.etype(), ErrorType::ConnectionClosed); + let msg = format!("{err}"); + assert!( + msg.contains("Prematurely before response header is sent"), + "unexpected error message: {msg}" + ); + } + + #[tokio::test] + async fn set_abort_on_close_toggles() { + init_log(); + let mut s = session_with_eof().await; + + assert!(s.abort_on_close); + s.set_abort_on_close(false); + assert!(!s.abort_on_close); + s.set_abort_on_close(true); + assert!(s.abort_on_close); + } +} diff --git a/pingora-core/src/protocols/http/v2/mod.rs b/pingora-core/src/protocols/http/v2/mod.rs index 01711807..8f664c9d 100644 --- a/pingora-core/src/protocols/http/v2/mod.rs +++ b/pingora-core/src/protocols/http/v2/mod.rs @@ -93,13 +93,14 @@ mod test { use h2::frame::*; use http::{HeaderMap, Method, Uri}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream}; + use tokio::sync::{oneshot, watch}; use tokio_stream::StreamExt; use pingora_http::{RequestHeader, ResponseHeader}; use 
pingora_timeout::sleep; use crate::protocols::{ - http::v2::server::{handshake, HttpSession}, + http::v2::server::{self, handshake, HttpSession}, Digest, }; @@ -111,7 +112,10 @@ mod test { // Client handles.push(tokio::spawn(async move { - let conn = crate::connectors::http::v2::handshake(Box::new(client), 500, None) + use crate::connectors::http::v2::H2HandshakeSettings; + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 500; + let conn = crate::connectors::http::v2::handshake(Box::new(client), settings) .await .unwrap(); @@ -271,4 +275,350 @@ mod test { assert!(handle.await.is_ok()); } } + + #[tokio::test] + async fn test_graceful_shutdown_processes_inflight_stream() { + // HEADERS arrive on the server after the shutdown signal + // fires, but before the client has observed GOAWAY. + let (mut client, server) = duplex(65536); + // Use channels for deterministic timing. + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let (write_headers_tx, write_headers_rx) = oneshot::channel::<()>(); + + let client_handle = tokio::spawn(async move { + client + .write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + .await + .unwrap(); + let mut codec: h2::Codec = h2::Codec::new(client); + codec.send(Settings::default().into()).await.unwrap(); + codec.send(Settings::ack().into()).await.unwrap(); + + // Wait until the test has triggered shutdown on the server before + // sending HEADERS. See the function-level comment for why. 
+ write_headers_rx.await.unwrap(); + + let mut headers = Headers::new( + 1.into(), + Pseudo::request( + Method::GET, + Uri::from_static("https://one.one.one.one/"), + None, + ), + HeaderMap::new(), + ); + headers.set_end_headers(); + headers.set_end_stream(); + codec.send(headers.into()).await.unwrap(); + + let mut saw_response = false; + let mut saw_goaway = false; + let _ = pingora_timeout::timeout(Duration::from_secs(5), async { + while let Some(frame) = codec.next().await { + match frame.unwrap() { + h2::frame::Frame::Headers(_) => { + saw_response = true; + } + h2::frame::Frame::GoAway(_) => { + saw_goaway = true; + } + _ => {} + } + if saw_response && saw_goaway { + break; + } + } + }) + .await; + + assert!(saw_response, "expected response for stream 1"); + assert!(saw_goaway, "expected at least one GOAWAY frame"); + }); + + let connection = handshake(Box::new(server), None).await.unwrap(); + let digest = Arc::new(Digest::default()); + + let trigger = tokio::spawn(async move { + sleep(Duration::from_millis(50)).await; + shutdown_tx.send(true).unwrap(); + // Wait long enough that the server task is guaranteed to have + // observed the shutdown signal and committed to its post-shutdown + // path before putting anything on the wire. 
+ sleep(Duration::from_millis(50)).await; + write_headers_tx.send(()).unwrap(); + }); + + let mut session_handles = vec![]; + server::accept_downstream_sessions(connection, digest, shutdown_rx, |mut session| { + session_handles.push(tokio::spawn(async move { + let req = session.req_header(); + assert_eq!(req.method, Method::GET); + let resp = Box::new(ResponseHeader::build(200, None).unwrap()); + session.write_response_header(resp, true).unwrap(); + })); + }) + .await; + + trigger.await.unwrap(); + assert_eq!( + session_handles.len(), + 1, + "expected stream 1 to be surfaced after shutdown_initiated" + ); + for h in session_handles { + h.await.unwrap(); + } + client_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_graceful_shutdown_processes_post_goaway_stream() { + // Client opens stream 1 after it has observed the + // server's GOAWAY frame. Stream 1 is below the GOAWAY(MAX) + // last_stream_id, so per RFC 9113 §6.8 the server must still process + // it. + let (mut client, server) = duplex(65536); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + let client_handle = tokio::spawn(async move { + client + .write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + .await + .unwrap(); + let mut codec: h2::Codec = h2::Codec::new(client); + codec.send(Settings::default().into()).await.unwrap(); + codec.send(Settings::ack().into()).await.unwrap(); + + // Block until the server's GOAWAY is observed. HEADERS for + // stream 1 are sent strictly after this point. 
+ let mut saw_goaway_before_headers = false; + while let Some(frame) = codec.next().await { + if matches!(frame.unwrap(), h2::frame::Frame::GoAway(_)) { + saw_goaway_before_headers = true; + break; + } + } + assert!( + saw_goaway_before_headers, + "expected GOAWAY before sending HEADERS for stream 1" + ); + + let mut headers = Headers::new( + 1.into(), + Pseudo::request( + Method::GET, + Uri::from_static("https://one.one.one.one/"), + None, + ), + HeaderMap::new(), + ); + headers.set_end_headers(); + headers.set_end_stream(); + codec.send(headers.into()).await.unwrap(); + + let mut saw_response_for_stream_1 = false; + let _ = pingora_timeout::timeout(Duration::from_secs(5), async { + while let Some(frame) = codec.next().await { + if let Ok(h2::frame::Frame::Headers(h)) = frame { + if h.stream_id() == 1u32 { + saw_response_for_stream_1 = true; + break; + } + } + } + }) + .await; + assert!( + saw_response_for_stream_1, + "expected response on stream 1 after GOAWAY", + ); + }); + + let connection = handshake(Box::new(server), None).await.unwrap(); + let digest = Arc::new(Digest::default()); + + let trigger = tokio::spawn(async move { + sleep(Duration::from_millis(50)).await; + shutdown_tx.send(true).unwrap(); + }); + + let mut session_handles = vec![]; + server::accept_downstream_sessions(connection, digest, shutdown_rx, |mut session| { + session_handles.push(tokio::spawn(async move { + let resp = Box::new(ResponseHeader::build(200, None).unwrap()); + session.write_response_header(resp, true).unwrap(); + })); + }) + .await; + + trigger.await.unwrap(); + assert_eq!( + session_handles.len(), + 1, + "expected exactly one stream surfaced after GOAWAY" + ); + for h in session_handles { + h.await.unwrap(); + } + client_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_graceful_shutdown_idle_connection_exits_promptly() { + let (mut client, server) = duplex(65536); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + let client_handle = tokio::spawn(async 
move { + client + .write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + .await + .unwrap(); + let mut codec: h2::Codec = h2::Codec::new(client); + codec.send(Settings::default().into()).await.unwrap(); + codec.send(Settings::ack().into()).await.unwrap(); + + // Wait for the server's GOAWAY, then drop the codec to close the + // connection so the accept loop can exit. + let mut saw_goaway = false; + let _ = pingora_timeout::timeout(Duration::from_secs(3), async { + while let Some(frame) = codec.next().await { + if matches!(frame.unwrap(), h2::frame::Frame::GoAway(_)) { + saw_goaway = true; + break; + } + } + }) + .await; + assert!(saw_goaway, "expected GOAWAY"); + }); + + let connection = handshake(Box::new(server), None).await.unwrap(); + let digest = Arc::new(Digest::default()); + + let trigger = tokio::spawn(async move { + sleep(Duration::from_millis(20)).await; + shutdown_tx.send(true).unwrap(); + }); + + let result = pingora_timeout::timeout( + Duration::from_secs(2), + server::accept_downstream_sessions(connection, digest, shutdown_rx, |_session| { + panic!("did not expect any sessions on an idle connection"); + }), + ) + .await; + assert!(result.is_ok(), "accept loop hung after shutdown"); + + trigger.await.unwrap(); + client_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_graceful_shutdown_refuses_stream_above_last_stream_id() { + // After the server commits to a final last_stream_id and emits the + // closing GOAWAY, any stream the client tries to open above that id + // must be refused. The accept loop must not surface it and must exit + // cleanly via the `Ok(None)` arm of `from_h2_conn`. 
+ let (mut client, server) = duplex(65536); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + let client_handle = tokio::spawn(async move { + client + .write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + .await + .unwrap(); + let mut codec: h2::Codec = h2::Codec::new(client); + codec.send(Settings::default().into()).await.unwrap(); + codec.send(Settings::ack().into()).await.unwrap(); + + // Open stream 1 so the server will commit last_stream_id >= 1 + // when it eventually emits its closing GOAWAY. + let mut headers = Headers::new( + 1.into(), + Pseudo::request( + Method::GET, + Uri::from_static("https://one.one.one.one"), + None, + ), + HeaderMap::new(), + ); + headers.set_end_headers(); + headers.set_end_stream(); + codec.send(headers.into()).await.unwrap(); + + // Drain frames from the server. Break once we've seen the + // response and at least one GOAWAY so the test doesn't race + // its own outer timeout while waiting on a quiet codec. + let mut saw_response = false; + let mut saw_goaway = false; + let _ = pingora_timeout::timeout(Duration::from_secs(3), async { + while let Some(frame) = codec.next().await { + match frame { + Ok(h2::frame::Frame::Headers(h)) if h.stream_id() == 1 => { + saw_response = true; + } + Ok(h2::frame::Frame::GoAway(_)) => { + saw_goaway = true; + } + Ok(_) => {} + Err(_) => break, + } + if saw_response && saw_goaway { + break; + } + } + }) + .await; + assert!(saw_response, "expected response for stream 1"); + assert!(saw_goaway, "expected at least one GOAWAY frame"); + + // Try to open stream 3 (above last_stream_id). The send may + // succeed locally (duplex buffer) or fail (peer half closed); + // either way the server-side codec must not surface the stream. 
+ let mut headers = Headers::new( + 3.into(), + Pseudo::request( + Method::GET, + Uri::from_static("https://one.one.one.one"), + None, + ), + HeaderMap::new(), + ); + headers.set_end_headers(); + headers.set_end_stream(); + let _ = codec.send(headers.into()).await; + }); + + let connection = handshake(Box::new(server), None).await.unwrap(); + let digest = Arc::new(Digest::default()); + + let trigger = tokio::spawn(async move { + sleep(Duration::from_millis(50)).await; + shutdown_tx.send(true).unwrap(); + }); + + let mut session_handles = vec![]; + let result = pingora_timeout::timeout( + Duration::from_secs(5), + server::accept_downstream_sessions(connection, digest, shutdown_rx, |mut session| { + session_handles.push(tokio::spawn(async move { + let resp = Box::new(ResponseHeader::build(200, None).unwrap()); + session.write_response_header(resp, true).unwrap(); + })); + }), + ) + .await; + assert!(result.is_ok(), "accept loop hung after shutdown"); + assert_eq!( + session_handles.len(), + 1, + "only stream 1 may be surfaced; streams above last_stream_id must be refused" + ); + + trigger.await.unwrap(); + for h in session_handles { + h.await.unwrap(); + } + client_handle.await.unwrap(); + } } diff --git a/pingora-core/src/protocols/http/v2/server.rs b/pingora-core/src/protocols/http/v2/server.rs index 363b7357..604d53c6 100644 --- a/pingora-core/src/protocols/http/v2/server.rs +++ b/pingora-core/src/protocols/http/v2/server.rs @@ -34,6 +34,7 @@ use crate::protocols::http::date::get_cached_date; use crate::protocols::http::v1::client::http_req_header_to_wire; use crate::protocols::http::HttpTask; use crate::protocols::{Digest, SocketAddr, Stream}; +use crate::server::ShutdownWatch; use crate::{Error, ErrorType, OrErr, Result}; const BODY_BUF_LIMIT: usize = 1024 * 64; @@ -63,6 +64,77 @@ pub async fn handshake(io: Stream, options: Option) -> Result( + mut conn: H2Connection, + digest: Arc, + mut shutdown: ShutdownWatch, + mut on_session: F, +) where + F: 
FnMut(HttpSession), +{ + let mut shutdown_initiated = false; + loop { + let h2_stream = if shutdown_initiated { + HttpSession::from_h2_conn(&mut conn, digest.clone()).await + } else { + tokio::select! { + // Poll the shutdown signal first so a concurrent signal is + // observed deterministically. `from_h2_conn` is cancel-safe + // and is polled again on the next iteration. + biased; + _ = shutdown.changed() => { + conn.graceful_shutdown(); + shutdown_initiated = true; + continue; + } + h2_stream = HttpSession::from_h2_conn(&mut conn, digest.clone()) => h2_stream, + } + }; + match h2_stream { + Err(e) => { + // It is common for the client to just disconnect TCP without + // properly closing H2. So we don't log the errors here + debug!("H2 error when accepting new stream {e}"); + return; + } + // None means the connection is ready to be closed + Ok(None) => return, + Ok(Some(session)) => on_session(session), + } + } +} + use futures::task::Context; use futures::task::Poll; use std::pin::Pin; diff --git a/pingora-core/src/protocols/l4/listener.rs b/pingora-core/src/protocols/l4/listener.rs index 7d00005e..a6055267 100644 --- a/pingora-core/src/protocols/l4/listener.rs +++ b/pingora-core/src/protocols/l4/listener.rs @@ -67,6 +67,20 @@ impl AsRawSocket for Listener { } impl Listener { + /// Return the local address this listener is bound to. + /// + /// For TCP listeners this is the resolved address (including the + /// OS-assigned port when the listener was bound to port 0). + /// Returns `None` for non-TCP listeners (e.g. Unix domain sockets). 
+ #[cfg(test)] + pub fn local_addr(&self) -> Option { + match self { + Self::Tcp(l) => l.local_addr().ok(), + #[cfg(unix)] + Self::Unix(_) => None, + } + } + /// Accept a connection from the listening endpoint pub async fn accept(&self) -> io::Result { match &self { diff --git a/pingora-core/src/protocols/l4/stream.rs b/pingora-core/src/protocols/l4/stream.rs index 4aa70f70..7cbbd37c 100644 --- a/pingora-core/src/protocols/l4/stream.rs +++ b/pingora-core/src/protocols/l4/stream.rs @@ -354,14 +354,69 @@ impl AsRawSocket for RawStreamWrapper { } } -// Large read buffering helps reducing syscalls with little trade-off -// Ssl layer always does "small" reads in 16k (TLS record size) so L4 read buffer helps a lot. -const BUF_READ_SIZE: usize = 64 * 1024; -// Small write buf to match MSS. Too large write buf delays real time communication. -// This buffering effectively implements something similar to Nagle's algorithm. -// The benefit is that user space can control when to flush, where Nagle's can't be controlled. -// And userspace buffering reduce both syscalls and small packets. -const BUF_WRITE_SIZE: usize = 1460; +/// The default L4 read buffer size. +/// +/// Large read buffering helps reducing syscalls with little trade-off. The SSL +/// layer always does "small" reads in 16k chunks (TLS record size), so L4 read +/// buffering helps a lot. +pub const DEFAULT_L4_READ_BUFFER_SIZE: usize = 64 * 1024; + +/// The default L4 write buffer size. +/// +/// Small write buffering matches a typical MSS. Too large a write buffer delays +/// real-time communication. This buffering effectively implements something +/// similar to Nagle's algorithm, but user space can control when to flush. +pub const DEFAULT_L4_WRITE_BUFFER_SIZE: usize = 1460; + +/// L4 [`BufStream`] buffer sizing. +/// +/// Leaving either side as `None` preserves Pingora's default for that side. +/// Setting either side to `Some(0)` disables `BufStream` buffering for that +/// direction. 
+#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +pub struct L4BufferSettings { + /// Read buffer size in bytes. `None` uses [`DEFAULT_L4_READ_BUFFER_SIZE`]. + pub read: Option, + /// Write buffer size in bytes. `None` uses [`DEFAULT_L4_WRITE_BUFFER_SIZE`]. + pub write: Option, +} + +impl L4BufferSettings { + /// Create settings with both read and write buffer sizes set explicitly. + pub fn new(read: usize, write: usize) -> Self { + Self { + read: Some(read), + write: Some(write), + } + } + + /// Create settings that disable both read and write `BufStream` buffering. + pub fn unbuffered() -> Self { + Self::new(0, 0) + } + + /// Set the read buffer size. + pub fn read(mut self, read: usize) -> Self { + self.read = Some(read); + self + } + + /// Set the write buffer size. + pub fn write(mut self, write: usize) -> Self { + self.write = Some(write); + self + } + + /// Resolved read buffer size after applying defaults. + pub fn read_capacity(&self) -> usize { + self.read.unwrap_or(DEFAULT_L4_READ_BUFFER_SIZE) + } + + /// Resolved write buffer size after applying defaults. + pub fn write_capacity(&self) -> usize { + self.write.unwrap_or(DEFAULT_L4_WRITE_BUFFER_SIZE) + } +} // NOTE: with writer buffering, users need to call flush() to make sure the data is actually // sent. Otherwise data could be stuck in the buffer forever or get lost when stream is closed. 
@@ -456,13 +511,18 @@ impl Stream { /// Set the buffer of BufStream /// It is only set later because of the malloc overhead in critical accept() path - pub(crate) fn set_buffer(&mut self) { + pub(crate) fn set_buffer(&mut self, buffer: L4BufferSettings) { use std::mem; // Since BufStream doesn't provide an API to adjust the buf directly, // we take the raw stream out of it and put it in a new BufStream with the size we want let stream = mem::take(&mut self.stream); - let stream = - stream.map(|s| BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, s.into_inner())); + let stream = stream.map(|s| { + BufStream::with_capacity( + buffer.read_capacity(), + buffer.write_capacity(), + s.into_inner(), + ) + }); let _ = mem::replace(&mut self.stream, stream); } } @@ -814,14 +874,67 @@ pub mod async_write_vec { fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { let me = &mut *self; - while me.buf.has_remaining() { - let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?; - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } + poll_write_vec_all_buf(ctx, Pin::new(&mut *me.writer), me.buf) + } + } + + /// Primitive poll function to write ALL bytes from a buffer using vectored writes. + /// Keeps polling `poll_write_vec` until the entire buffer is written. + /// The buffer is advanced as bytes are written. + /// + /// Returns Poll::Ready(Ok(())) when all bytes are written. + /// Returns WriteZero error if poll_write_vec returns 0. + /// + /// This is essentially a polling form of tokio's + /// [`write_all_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncWriteExt.html#method.write_all_buf). + // TODO: we should be able to switch over to polling the future from tokio AsyncWriteExt directly, + // for now we continue to use the old trait. 
+ pub fn poll_write_vec_all_buf( + ctx: &mut Context<'_>, + mut writer: Pin<&mut W>, + buf: &mut B, + ) -> Poll> + where + W: AsyncWriteVec + ?Sized, + B: Buf, + { + while buf.has_remaining() { + let n = ready!(writer.as_mut().poll_write_vec(ctx, buf))?; + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } - Poll::Ready(Ok(())) } + Poll::Ready(Ok(())) + } + + /// Primitive poll function to write ALL bytes from a buffer using regular writes. + /// Keeps polling `poll_write` until the entire buffer is written. + /// The buffer is advanced as bytes are written. + /// + /// Returns Poll::Ready(Ok(())) when all bytes are written. + /// Returns WriteZero error if poll_write returns 0. + /// + /// This is essentially a polling form of tokio's + /// [`write_all_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncWriteExt.html#method.write_all_buf) + /// though we explicitly use non-vectored writes in this case for strict parity with the + /// original `write_all` method. + pub fn poll_write_all_buf( + ctx: &mut Context<'_>, + mut writer: Pin<&mut W>, + buf: &mut B, + ) -> Poll> + where + W: AsyncWrite + ?Sized, + B: Buf, + { + while buf.has_remaining() { + let n = ready!(writer.as_mut().poll_write(ctx, buf.chunk()))?; + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + buf.advance(n); + } + Poll::Ready(Ok(())) } /* from https://github.com/tokio-rs/tokio/blob/master/tokio-util/src/lib.rs#L177 */ diff --git a/pingora-core/src/protocols/proxy_protocol.rs b/pingora-core/src/protocols/proxy_protocol.rs index a6700993..bcddd4bf 100644 --- a/pingora-core/src/protocols/proxy_protocol.rs +++ b/pingora-core/src/protocols/proxy_protocol.rs @@ -26,6 +26,12 @@ use crate::protocols::l4::stream::Stream; /// Re-export the parsed header type from the underlying `proxy_protocol` crate for convenience. 
pub use proxy_protocol::ProxyHeader; +/// Re-export the v2 extension TLV enum so callers of +/// `set_proxy_v2_tlv_callback` (registered in `listeners::mod`) can +/// pattern-match on `ExtensionTlv::Custom { type_id, value }` for +/// application-defined TLVs in the `0xE0..=0xEF` range. +pub use proxy_protocol::version2::ExtensionTlv; + /// Maximum number of bytes a PROXY protocol v1 header can occupy, including CRLF. pub const MAX_PROXY_V1_HEADER_LEN: usize = 108; /// Maximum number of bytes a PROXY protocol v2 header can occupy (16 bytes header + 64k body). diff --git a/pingora-core/src/server/bootstrap_services.rs b/pingora-core/src/server/bootstrap_services.rs index 0ad27ffc..74c81d79 100644 --- a/pingora-core/src/server/bootstrap_services.rs +++ b/pingora-core/src/server/bootstrap_services.rs @@ -18,13 +18,17 @@ use async_trait::async_trait; use log::{debug, error, info}; use parking_lot::Mutex; use std::sync::Arc; -use tokio::sync::{broadcast, Mutex as TokioMutex}; +use tokio::sync::broadcast; #[cfg(feature = "sentry")] use sentry::ClientOptions; +#[cfg(unix)] +use crate::server::daemon::notify_parent_ready_for_fds; #[cfg(unix)] use crate::server::ListenFds; +#[cfg(unix)] +use std::time::Duration; use crate::{ prelude::Opt, @@ -32,6 +36,10 @@ use crate::{ services::{background::BackgroundService, ServiceReadyNotifier}, }; +/// Default timeout for retrying `SIGUSR1` to the parent when it fails with `EPERM`. +#[cfg(unix)] +const DEFAULT_DAEMON_NOTIFY_TIMEOUT: Duration = Duration::from_secs(60); + /// Service that allows the bootstrap process to be delayed until after /// dependencies are ready pub struct BootstrapService { @@ -58,7 +66,17 @@ pub struct Bootstrap { execution_phase_watch: broadcast::Sender, #[cfg(unix)] - listen_fds: Option, + listen_fds: ListenFds, + + /// PID of the original parent process to notify via `SIGUSR1` after bootstrap completes. + /// Set when [`ServerConf::daemon_wait_for_ready`] is `true`. 
+ #[cfg(unix)] + notify_parent_pid: Option, + + /// How long to keep retrying `SIGUSR1` to the parent when it fails with `EPERM`. + /// See [`ServerConf::daemon_notify_timeout_seconds`]. + #[cfg(unix)] + daemon_notify_timeout: std::time::Duration, #[cfg(feature = "sentry")] #[cfg_attr(docsrs, doc(cfg(feature = "sentry")))] @@ -95,7 +113,14 @@ impl Bootstrap { upgrade, upgrade_sock, #[cfg(unix)] - listen_fds: None, + listen_fds: Arc::new(Mutex::new(Fds::new())), + #[cfg(unix)] + notify_parent_pid: None, + #[cfg(unix)] + daemon_notify_timeout: conf + .daemon_notify_timeout_seconds + .map(|n| Duration::from_secs(n.get())) + .unwrap_or(DEFAULT_DAEMON_NOTIFY_TIMEOUT), execution_phase_watch: execution_phase_watch.clone(), completed: false, #[cfg(feature = "sentry")] @@ -110,6 +135,13 @@ impl Bootstrap { self.sentry = sentry_config; } + /// Store the parent process PID to notify via `SIGUSR1` after bootstrap completes. + /// Only relevant when [`ServerConf::daemon_wait_for_ready`] is `true`. + #[cfg(unix)] + pub fn set_notify_parent_pid(&mut self, pid: u32) { + self.notify_parent_pid = Some(pid); + } + /// Initialize the Sentry client from the configured [`ClientOptions`] and /// store the resulting guard. /// @@ -161,7 +193,17 @@ impl Bootstrap { std::process::exit(0); } - // load fds + // Notify the parent process that it can exit. It might seem like we should load the file + // descriptors from the old process first, but the purpose of this notification is to + // release the parent so that the process managing it (e.g. systemd) can continue and send + // a quit signal to the old process. That quit signal is required before the old process + // will start trying to send its file descriptors to us — so if we called load_fds first, + // we would be guaranteeing a timeout. 
+ #[cfg(unix)] + if let Some(pid) = self.notify_parent_pid { + notify_parent_ready_for_fds(pid, self.daemon_notify_timeout); + } + #[cfg(unix)] match self.load_fds(self.upgrade) { Ok(_) => { @@ -186,17 +228,18 @@ impl Bootstrap { #[cfg(unix)] fn load_fds(&mut self, upgrade: bool) -> Result<(), nix::Error> { - let mut fds = Fds::new(); if upgrade { debug!("Trying to receive socks"); - fds.get_from_sock(self.upgrade_sock.as_str())? + let mut fds = Fds::new(); + fds.get_from_sock(self.upgrade_sock.as_str())?; + // Mutate through the existing Arc so all clones held by services see the update. + *self.listen_fds.lock() = fds; } - self.listen_fds = Some(Arc::new(TokioMutex::new(fds))); Ok(()) } #[cfg(unix)] - pub fn get_fds(&self) -> Option { + pub fn get_fds(&self) -> ListenFds { self.listen_fds.clone() } } diff --git a/pingora-core/src/server/configuration/mod.rs b/pingora-core/src/server/configuration/mod.rs index acfc1b21..3e56de4c 100644 --- a/pingora-core/src/server/configuration/mod.rs +++ b/pingora-core/src/server/configuration/mod.rs @@ -25,6 +25,8 @@ use pingora_error::{Error, ErrorType::*, OrErr, Result}; use serde::{Deserialize, Serialize}; use std::ffi::OsString; use std::fs; +use std::num::NonZeroU64; +use std::path::PathBuf; // default maximum upstream retries for retry-able proxy errors const DEFAULT_MAX_RETRIES: usize = 16; @@ -63,6 +65,11 @@ pub struct ServerConf { pub user: Option, /// Similar to `user`, the group this process should switch to. pub group: Option, + /// Working directory for the daemonized process. + /// + /// Only applied when `daemon` is `true`; set this to start the daemon from a known cwd. + // TODO: other OS path options should likely be `PathBuf` as well. + pub working_directory: Option, /// How many threads **each** service should get. The threads are not shared across services. pub threads: usize, /// Number of listener tasks to use per fd. This allows for parallel accepts. 
@@ -129,6 +136,45 @@ pub struct ServerConf { /// /// When not set, the tokio default (10 seconds) is used. pub blocking_threads_ttl_seconds: Option, + /// When `daemon` is `true`, controls whether the parent process of the daemon fork waits for + /// the child to signal readiness before exiting. + /// + /// When `false` (default), the parent exits immediately after the daemon fork, matching the + /// traditional daemonization behavior. Systemd will consider the service started as soon as + /// the parent exits, which may be before the child has finished bootstrapping. + /// + /// When `true`, the parent waits (up to [`Self::daemon_ready_timeout_seconds`]) for the child + /// to send `SIGUSR1` after bootstrap completes. This causes systemd to delay any subsequent + /// steps (such as sending `SIGQUIT` to the old process) until the new instance is fully ready + /// to serve traffic. If the child does not signal in time, the parent exits with a non-zero + /// exit code, causing systemd to abort the reload. + pub daemon_wait_for_ready: bool, + /// Timeout in seconds for the parent process to wait for the child to signal readiness during + /// daemonization when [`Self::daemon_wait_for_ready`] is `true`. + /// + /// If the child does not send `SIGUSR1` within this timeout, the parent exits with a non-zero + /// exit code. + /// + /// Defaults to 600 seconds (10 minutes). + pub daemon_ready_timeout_seconds: Option, + /// How long the child process will keep retrying `SIGUSR1` to the parent when the signal + /// fails with a permission error (`EPERM`) during daemonization. + /// + /// After the daemon fork, the parent always drops its credentials to the configured user and + /// group (see [`Self::user`], [`Self::group`]). Because the privilege drop happens after the + /// fork, there is a small window where the child may attempt to signal the parent before the + /// parent has finished changing its credentials. 
During this window the kernel will reject the + /// signal with `EPERM` because the child and parent are running as different users. The child + /// retries every 100 ms until this timeout elapses. + /// + /// In practice this window is very small, so the default of 60 seconds is far more than + /// enough to account for it. + /// + /// Only retries on `EPERM`; any other error (e.g. `ESRCH` — parent no longer exists) is + /// treated as fatal and logged without retrying. + /// + /// Defaults to 60 seconds. + pub daemon_notify_timeout_seconds: Option, } impl Default for ServerConf { @@ -148,6 +194,7 @@ impl Default for ServerConf { upgrade_sock: "/tmp/pingora_upgrade.sock".to_string(), user: None, group: None, + working_directory: None, threads: 1, listener_tasks_per_fd: 1, work_stealing: true, @@ -160,6 +207,9 @@ impl Default for ServerConf { upgrade_sock_connect_accept_max_retries: None, max_blocking_threads: None, blocking_threads_ttl_seconds: None, + daemon_ready_timeout_seconds: None, + daemon_wait_for_ready: false, + daemon_notify_timeout_seconds: None, } } } @@ -320,6 +370,7 @@ mod tests { upgrade_sock: "".to_string(), user: None, group: None, + working_directory: None, threads: 1, listener_tasks_per_fd: 1, work_stealing: true, @@ -332,6 +383,9 @@ mod tests { upgrade_sock_connect_accept_max_retries: None, max_blocking_threads: None, blocking_threads_ttl_seconds: None, + daemon_ready_timeout_seconds: None, + daemon_wait_for_ready: false, + daemon_notify_timeout_seconds: None, }; // cargo test -- --nocapture not_a_test_i_cannot_write_yaml_by_hand println!("{}", conf.to_yaml()); @@ -372,6 +426,29 @@ version: 1 assert!(!conf.enable_proxy_protocol); } + #[test] + fn test_working_directory_deserializes_from_yaml_string() { + init_log(); + let conf_str = r#" +--- +version: 1 +daemon: true +working_directory: /var/lib/pingora + "#; + + let conf = ServerConf::from_yaml(conf_str).unwrap(); + assert_eq!( + conf.working_directory.as_deref(), + 
Some(std::path::Path::new("/var/lib/pingora")) + ); + + let yaml = serde_yaml::to_value(&conf).unwrap(); + assert_eq!( + yaml.get("working_directory"), + Some(&serde_yaml::Value::String("/var/lib/pingora".to_string())) + ); + } + #[test] fn test_zero_max_blocking_threads_is_rejected() { init_log(); diff --git a/pingora-core/src/server/daemon.rs b/pingora-core/src/server/daemon.rs index 7381fc93..d225ca22 100644 --- a/pingora-core/src/server/daemon.rs +++ b/pingora-core/src/server/daemon.rs @@ -12,18 +12,71 @@ // See the License for the specific language governing permissions and // limitations under the License. -use daemonize::{Daemonize, Stdio}; -use log::{debug, error}; +use daemonize::{Daemonize, Outcome, Stdio}; +use log::{debug, error, info}; +use pingora_error::{Error, ErrorType, OrErr, Result}; use std::ffi::CString; use std::fs::{self, OpenOptions}; use std::os::unix::prelude::OpenOptionsExt; use std::path::Path; +use std::process; +use std::thread; +use std::time::{Duration, Instant}; use crate::server::configuration::ServerConf; +/// Error returned by [`send_signal`]. +#[derive(Debug)] +pub(crate) enum SignalError { + /// The caller does not have permission to send the signal to the target process (`EPERM`). + PermissionDenied, + /// Any other error from `kill(2)`. Contains the raw `errno` value. + OtherSignalError(i32), +} + +impl std::fmt::Display for SignalError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SignalError::PermissionDenied => write!(f, "permission denied (EPERM)"), + SignalError::OtherSignalError(errno) => { + write!(f, "kill failed with errno {errno}") + } + } + } +} + +/// Send `signal` to the process identified by `pid`. +/// +/// Returns `Ok(())` on success. On failure, maps `errno` to [`SignalError`]: +/// - `EPERM` → [`SignalError::PermissionDenied`] +/// - anything else → [`SignalError::OtherSignalError`] containing the raw errno value. 
+fn send_signal(pid: libc::pid_t, signal: libc::c_int) -> Result<(), SignalError> { + // SAFETY: `kill(2)` is safe to call with any pid/signal combination — invalid values + // simply return an error via errno rather than causing undefined behavior. + let ret = unsafe { libc::kill(pid, signal) }; + if ret == 0 { + return Ok(()); + } + let errno = std::io::Error::last_os_error().raw_os_error().unwrap_or(-1); + if errno == libc::EPERM { + Err(SignalError::PermissionDenied) + } else { + Err(SignalError::OtherSignalError(errno)) + } +} + // Utilities to daemonize a pingora server, i.e. run the process in the background, possibly // under a different running user and/or group. +/// Default timeout for the parent to wait for the daemon child to signal readiness. +const DEFAULT_DAEMON_READY_TIMEOUT: Duration = Duration::from_secs(600); + +/// How long to sleep between `SIGUSR1` send attempts when `EPERM` is returned. +const NOTIFY_RETRY_INTERVAL: Duration = Duration::from_millis(100); + +/// How long to sleep between pid-file liveness checks in the async wait loop. +const LIVENESS_CHECK_INTERVAL: Duration = Duration::from_millis(100); + // XXX: this operation should have been done when the old service is exiting. // Now the new pid file just kick the old one out of the way fn move_old_pid(path: &str) { @@ -45,7 +98,15 @@ fn move_old_pid(path: &str) { } } +/// # Safety +/// +/// `name` must be a valid, null-terminated C string. The returned `gid_t` is read from the +/// `passwd` struct returned by `getpwnam(3)`, which points to a static buffer that may be +/// overwritten by subsequent calls to `getpwnam` or `getpwuid`. The caller must not hold the +/// pointer across such calls. unsafe fn gid_for_username(name: &CString) -> Option { + // SAFETY: `name` is a valid CString; `getpwnam` returns a pointer to a static buffer + // or null. We read `pw_gid` immediately and do not retain the pointer. 
let passwd = libc::getpwnam(name.as_ptr() as *const libc::c_char); if !passwd.is_null() { return Some((*passwd).pw_gid); @@ -53,15 +114,287 @@ unsafe fn gid_for_username(name: &CString) -> Option { None } +/// Drop the parent process's UID to the user specified in [`ServerConf::user`]. +/// +/// The kernel only permits a process to send a signal to another if they share the same UID (or +/// the sender is root). Since the daemon child sends `SIGUSR1` to the parent to signal readiness, +/// the parent must be running as the same UID as the child by the time that signal arrives — +/// otherwise the kernel will reject it with `EPERM`. +/// +/// This function is called in the `Outcome::Parent` path immediately after `execute()` returns, +/// before the parent enters its readiness wait loop, so the parent's UID matches the child's as +/// quickly as possible after the fork. +/// +/// Only the UID is changed; the GID is left as-is. Signal permission checks are based on UID, +/// so changing the GID is not necessary for this purpose. +/// +/// Returns an error if the user name is invalid, cannot be resolved, or `setuid` fails; the +/// caller ([`daemonize`]) logs the error and exits the parent with code 1. The child's +/// `EPERM` retry window (see [`ServerConf::daemon_notify_timeout_seconds`]) exists precisely to +/// cover the small gap between the fork and the parent completing this UID change. +fn drop_privileges_in_parent(conf: &ServerConf) -> Result<()> { + let Some(user) = conf.user.as_ref() else { + return Ok(()); + }; + + let user_cstr = CString::new(user.as_str()).or_err_with(ErrorType::Custom("Daemon"), || { + format!("drop_privileges_in_parent: user '{user}' invalid") + })?; + + // SAFETY: `user_cstr` is a valid CString. `getpwnam` returns a pointer to a static + // buffer or null. We read `pw_uid` immediately and do not retain the pointer.
+ let passwd = unsafe { libc::getpwnam(user_cstr.as_ptr() as *const libc::c_char) }; + if passwd.is_null() { + return Error::e_explain( + ErrorType::Custom("Daemon"), + format!("drop_privileges_in_parent: user '{user}' not found"), + ); + } + + // SAFETY: `passwd` was checked for null above. We dereference it once to read `pw_uid`. + let uid = unsafe { (*passwd).pw_uid }; + // SAFETY: `setuid(2)` is safe to call with any uid — invalid values return an error. + let ret = unsafe { libc::setuid(uid) }; + if ret == 0 { + Ok(()) + } else { + Error::e_explain( + ErrorType::Custom("Daemon"), + format!( + "drop_privileges_in_parent: setuid({uid}) failed: {}", + std::io::Error::last_os_error() + ), + ) + } +} + +/// Outcome of calling [`daemonize`]. +/// +/// When [`ServerConf::daemon_wait_for_ready`] is `true`, the child process must call +/// [`notify_parent_ready_for_fds`] after bootstrap completes to unblock the parent's wait loop. +pub struct DaemonizeResult { + /// The PID of the original parent process to notify via `SIGUSR1` after bootstrap completes. + /// + /// `Some` when [`ServerConf::daemon_wait_for_ready`] is `true`, `None` otherwise. + pub notify_parent_pid: Option, +} + /// Start a server instance as a daemon. -#[cfg(unix)] -pub fn daemonize(conf: &ServerConf) { - // TODO: customize working dir +/// +/// Both code paths use [`daemonize::Daemonize::execute()`] rather than calling `fork()` directly. +/// `execute()` returns an [`Outcome`] to the caller in each process rather than having the parent +/// exit inside the crate, which gives us the opportunity to run additional logic in the parent +/// before it exits. +/// +/// When [`ServerConf::daemon_wait_for_ready`] is `false` (the default), the parent exits +/// immediately — matching the behavior of `start()`. 
+/// +/// When `daemon_wait_for_ready` is `true`, the parent registers a `SIGUSR1` listener after the +/// fork, then waits (periodically checking that the pid-file process is alive) for up to +/// [`ServerConf::daemon_ready_timeout_seconds`] (default 600 s) for the grandchild to send +/// `SIGUSR1`. On success the parent exits with code 0. On timeout, or if the daemon process +/// exits before signaling, the parent exits with code 1, causing systemd to abort the reload. +/// +/// Returns a [`DaemonizeResult`] that is only meaningful to the child process. The parent always +/// exits before returning. +pub fn daemonize(conf: &ServerConf) -> DaemonizeResult { + // Capture the parent PID before forking so it can be passed to the grandchild. The + // grandchild sends SIGUSR1 to this PID after bootstrap completes. + let parent_pid = if conf.daemon_wait_for_ready { + Some(process::id()) + } else { + None + }; + + move_old_pid(&conf.pid_file); + + match build_daemonize(conf).execute() { + Outcome::Parent(result) => { + result.unwrap_or_else(|e| panic!("Daemonize failed: {e}")); + } + Outcome::Child(result) => { + result.unwrap_or_else(|e| panic!("Daemonize child setup failed: {e}")); + return DaemonizeResult { + notify_parent_pid: parent_pid, + }; + } + } + + if conf.daemon_wait_for_ready { + // Drop root privileges before waiting so the parent does not linger as root.
+ if let Err(e) = drop_privileges_in_parent(conf) { + error!("drop_privileges_in_parent failed: {e}"); + + // Exiting the parent process should be fine because if downgrading + // the user's privileges fails here, it will fail in the child and + // the child will exit too + process::exit(1); + } + + let timeout = conf + .daemon_ready_timeout_seconds + .map(|n| Duration::from_secs(n.get())) + .unwrap_or(DEFAULT_DAEMON_READY_TIMEOUT); + + info!( + "Waiting up to {:?} for daemon to signal readiness via SIGUSR1", + timeout + ); + wait_for_ready_or_exit(&conf.pid_file, timeout); + } + + process::exit(0); +} + +/// Build a single-threaded tokio runtime for the parent's signal wait loop. +/// +/// The parent process is short-lived and only needs to wait for a signal and check the pid file. +/// A current-thread runtime avoids spawning worker threads in a process that is about to exit. +fn build_parent_runtime() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to build tokio runtime for parent signal wait") +} + +/// Wait for the daemon grandchild to send `SIGUSR1`, up to `timeout`. +/// +/// Uses a local tokio runtime with [`tokio::signal::unix`] to listen for `SIGUSR1` instead of +/// raw signal handlers and polling loops. The daemon's PID is checked periodically via the pid +/// file — if the process exits before signaling, the parent aborts. +/// +/// Exits the process directly: +/// - exit code 0 if `SIGUSR1` is received (daemon is ready). +/// - exit code 1 if `timeout` elapses (daemon took too long). +/// - exit code 1 if the pid file exists and the process is no longer running. 
+fn wait_for_ready_or_exit(pid_file: &str, timeout: Duration) { + let rt = build_parent_runtime(); + let pid_file = pid_file.to_owned(); + + rt.block_on(async move { + use tokio::signal::unix::{signal, SignalKind}; + use tokio::time::{interval, timeout as tokio_timeout}; + + let mut sigusr1 = + signal(SignalKind::user_defined1()).expect("failed to register SIGUSR1 listener"); + + let mut liveness_check = interval(LIVENESS_CHECK_INTERVAL); + let mut daemon_pid: Option = None; + + let result = tokio_timeout(timeout, async { + loop { + tokio::select! { + _ = sigusr1.recv() => { + info!("Daemon signaled readiness, parent exiting"); + return; + } + _ = liveness_check.tick() => { + if daemon_pid.is_none() { + daemon_pid = try_read_pid_file(&pid_file); + } + if let Some(pid) = daemon_pid { + if !process_is_running(pid) { + error!( + "Daemon process (pid {pid}) is no longer running \ + before signaling readiness, aborting" + ); + process::exit(1); + } + } + } + } + } + }) + .await; + + if result.is_err() { + error!("Daemon did not signal readiness within {timeout:?}, aborting"); + process::exit(1); + } + }); +} + +/// Notify the parent process that the daemon is ready to serve traffic by sending `SIGUSR1`. +/// +/// Should be called by the daemon process after bootstrap is complete when +/// [`ServerConf::daemon_wait_for_ready`] is `true`. `parent_pid` is the PID of the original +/// process captured before the fork and stored in [`DaemonizeResult::notify_parent_pid`]. +/// +/// `SIGUSR1` wakes the parent's wait loop (`wait_for_ready_or_exit`), causing it to exit with +/// code 0 and allowing systemd to proceed with the next step of the service reload.
+/// +/// If `kill(2)` returns `EPERM` — which can happen transiently while the parent has not yet +/// dropped to the same user as the daemon (see `drop_privileges_in_parent`) — the +/// function sleeps for [`NOTIFY_RETRY_INTERVAL`] (100 ms) and retries until `notify_timeout` +/// elapses, at which point it logs an error and returns. Any other error (e.g. `ESRCH`, +/// meaning the parent no longer exists) is logged and the function returns immediately without +/// retrying. +pub fn notify_parent_ready_for_fds(parent_pid: u32, notify_timeout: Duration) { + let parent_pid = parent_pid as libc::pid_t; + info!( + "Sending SIGUSR1 to parent process (pid {}) to signal daemon readiness", + parent_pid + ); + + let start = Instant::now(); + + while start.elapsed() < notify_timeout { + match send_signal(parent_pid, libc::SIGUSR1) { + Ok(()) => return, + Err(SignalError::PermissionDenied) => { + debug!( + "Permission denied sending SIGUSR1 to parent (pid {}), retrying in {:?}", + parent_pid, NOTIFY_RETRY_INTERVAL + ); + thread::sleep(NOTIFY_RETRY_INTERVAL); + } + Err(SignalError::OtherSignalError(errno)) => { + error!( + "Failed to send SIGUSR1 to parent (pid {}): errno {errno}", + parent_pid + ); + return; + } + } + } + + error!( + "Permission denied sending SIGUSR1 to parent (pid {}), giving up after {:?}", + parent_pid, notify_timeout + ); +} + +/// Try to read a PID from `pid_file`. Returns `None` if the file does not exist or cannot be +/// parsed. +fn try_read_pid_file(pid_file: &str) -> Option { + fs::read_to_string(pid_file) + .ok() + .and_then(|c| c.trim().parse().ok()) +} + +/// Returns `true` if a process with `pid` is currently running. +fn process_is_running(pid: libc::pid_t) -> bool { + // Signal 0 does not send a signal; it just checks whether the process exists and whether + // we have permission to signal it.
EPERM (no permission) is not possible here because + // drop_privileges_in_parent guarantees the parent has already dropped to the same user as + // the daemon child before this function is called. + send_signal(pid, 0).is_ok() +} + +/// Build a [`Daemonize`] instance configured from `conf`, without calling `start()` or +/// `execute()`. The caller is responsible for driving execution. +fn build_daemonize(conf: &ServerConf) -> Daemonize<()> { let daemonize = Daemonize::new() .umask(0o007) // allow same group to access files but not everyone else .pid_file(&conf.pid_file); + let daemonize = if let Some(working_directory) = conf.working_directory.as_ref() { + daemonize.working_directory(working_directory) + } else { + daemonize + }; + let daemonize = if let Some(error_log) = conf.error_log.as_ref() { let err = OpenOptions::new() .append(true) @@ -82,6 +415,7 @@ pub fn daemonize(conf: &ServerConf) { Some(user) => { let user_cstr = CString::new(user.as_str()).unwrap(); + // SAFETY: `user_cstr` is a valid CString. See `gid_for_username` safety docs. #[cfg(target_os = "macos")] let group_id = unsafe { gid_for_username(&user_cstr).map(|gid| gid as i32) }; #[cfg(target_os = "freebsd")] @@ -92,7 +426,8 @@ pub fn daemonize(conf: &ServerConf) { daemonize .privileged_action(move || { if let Some(gid) = group_id { - // Set the supplemental group privileges for the child process. + // SAFETY: `user_cstr` is a valid CString captured by the closure. + // `initgroups(3)` is safe to call with a valid username and gid. 
unsafe { libc::initgroups(user_cstr.as_ptr() as *const libc::c_char, gid); } @@ -104,12 +439,8 @@ pub fn daemonize(conf: &ServerConf) { None => daemonize, }; - let daemonize = match conf.group.as_ref() { + match conf.group.as_ref() { Some(group) => daemonize.group(group.as_str()), None => daemonize, - }; - - move_old_pid(&conf.pid_file); - - daemonize.start().unwrap(); // hard crash when fail + } } diff --git a/pingora-core/src/server/mod.rs b/pingora-core/src/server/mod.rs index 93a38bbc..cadd1ac1 100644 --- a/pingora-core/src/server/mod.rs +++ b/pingora-core/src/server/mod.rs @@ -36,7 +36,7 @@ use std::thread; use std::time::SystemTime; #[cfg(unix)] use tokio::signal::unix; -use tokio::sync::{broadcast, watch, Mutex as TokioMutex}; +use tokio::sync::{broadcast, watch}; use tokio::time::{sleep, Duration}; use crate::prelude::background_service; @@ -131,7 +131,7 @@ pub enum ExecutionPhase { /// to shutdown pub type ShutdownWatch = watch::Receiver; #[cfg(unix)] -pub type ListenFds = Arc>; +pub type ListenFds = Arc>; /// The type of shutdown process that has been requested. 
#[derive(Debug)] @@ -286,43 +286,47 @@ impl Server { .send(ExecutionPhase::GracefulUpgradeTransferringFds) .ok(); - if let Some(fds) = self.listen_fds() { - let fds = fds.lock().await; - info!("Trying to send socks"); - // XXX: this is blocking IO - match fds.send_to_sock(self.configuration.as_ref().upgrade_sock.as_str()) { - Ok(_) => { - info!("listener sockets sent"); - } - Err(e) => { - error!("Unable to send listener sockets to new process: {e}"); - // sentry log error on fd send failure - #[cfg(all(not(debug_assertions), feature = "sentry"))] - sentry::capture_error(&e); + let sent_fds = { + let fds = self.listen_fds(); + let fds = fds.lock(); + if fds.is_empty() { + info!("No socks to send, shutting down."); + false + } else { + info!("Trying to send socks"); + match fds.send_to_sock(self.configuration.as_ref().upgrade_sock.as_str()) { + Ok(_) => { + info!("listener sockets sent"); + } + Err(e) => { + error!("Unable to send listener sockets to new process: {e}"); + #[cfg(all(not(debug_assertions), feature = "sentry"))] + sentry::capture_error(&e); + } } + true } + }; + if sent_fds { self.execution_phase_watch .send(ExecutionPhase::GracefulUpgradeCloseTimeout) .ok(); sleep(Duration::from_secs(CLOSE_TIMEOUT)).await; - info!("Broadcasting graceful shutdown"); - // gracefully exiting - match self.shutdown_watch.send(true) { - Ok(_) => { - info!("Graceful shutdown started!"); - } - Err(e) => { - error!("Graceful shutdown broadcast failed: {e}"); - // switch to fast shutdown - return ShutdownType::Graceful; - } + } + info!("Broadcasting graceful shutdown"); + // gracefully exiting + match self.shutdown_watch.send(true) { + Ok(_) => { + info!("Graceful shutdown started!"); + } + Err(e) => { + error!("Graceful shutdown broadcast failed: {e}"); + // switch to fast shutdown + return ShutdownType::Graceful; } - info!("Broadcast graceful shutdown complete"); - ShutdownType::Graceful - } else { - info!("No socks to send, shutting down."); - ShutdownType::Graceful } + 
info!("Broadcast graceful shutdown complete"); + ShutdownType::Graceful } } } @@ -374,14 +378,14 @@ impl Server { /// Get the configured file descriptors for listening #[cfg(unix)] - fn listen_fds(&self) -> Option { + fn listen_fds(&self) -> ListenFds { self.bootstrap.lock().get_fds() } #[allow(clippy::too_many_arguments)] fn run_service( mut service: Box, - #[cfg(unix)] fds: Option, + #[cfg(unix)] fds: ListenFds, shutdown: ShutdownWatch, threads: usize, work_stealing: bool, @@ -420,7 +424,7 @@ impl Server { service .start_service( #[cfg(unix)] - fds, + Some(fds), shutdown, listeners_per_fd, ready_notifier, @@ -636,8 +640,13 @@ impl Server { if conf.daemon { info!("Daemonizing the server"); fast_timeout::pause_for_fork(); - daemonize(&self.configuration); + let daemonize_result = daemonize(&self.configuration); fast_timeout::unpause(); + // If daemon_wait_for_ready is enabled, pass the parent PID to bootstrap so it + // can send SIGUSR1 to the parent after bootstrap completes. + if let Some(pid) = daemonize_result.notify_parent_pid { + self.bootstrap.lock().set_notify_parent_pid(pid); + } } #[cfg(windows)] diff --git a/pingora-core/src/server/transfer_fd/mod.rs b/pingora-core/src/server/transfer_fd/mod.rs index 3f852aec..a2fa58cc 100644 --- a/pingora-core/src/server/transfer_fd/mod.rs +++ b/pingora-core/src/server/transfer_fd/mod.rs @@ -50,6 +50,10 @@ impl Fds { self.map.get(bind) } + pub fn is_empty(&self) -> bool { + self.map.is_empty() + } + pub fn serialize(&self) -> (Vec, Vec) { self.map.iter().map(|(key, val)| (key.clone(), val)).unzip() } diff --git a/pingora-core/src/services/listening.rs b/pingora-core/src/services/listening.rs index b6886c21..1810ba0c 100644 --- a/pingora-core/src/services/listening.rs +++ b/pingora-core/src/services/listening.rs @@ -23,7 +23,7 @@ use crate::listeners::tls::TlsSettings; #[cfg(feature = "connection_filter")] use crate::listeners::AcceptAllFilter; use crate::listeners::{ - ConnectionFilter, Listeners, ServerAddress, 
TcpSocketOptions, TransportStack, + ConnectionFilter, ListenerConfig, Listeners, ServerAddress, TcpSocketOptions, TransportStack, }; use crate::protocols::Stream; #[cfg(unix)] @@ -123,6 +123,11 @@ impl Service { self.listeners.add_tcp(addr); } + /// Add a listening endpoint configured by a [`ListenerConfig`]. + pub fn add_listener(&mut self, endpoint: ListenerConfig) { + self.listeners.add_listener(endpoint); + } + /// Add a TCP listening endpoint with the given [`TcpSocketOptions`]. pub fn add_tcp_with_settings(&mut self, addr: &str, sock_opt: TcpSocketOptions) { self.listeners.add_tcp_with_settings(addr, sock_opt); @@ -309,19 +314,3 @@ impl ServiceTrait for Service { self.threads } } - -#[cfg(feature = "prometheus")] -use crate::apps::prometheus_http_app::PrometheusServer; - -#[cfg(feature = "prometheus")] -impl Service { - /// The Prometheus HTTP server - /// - /// The HTTP server endpoint that reports Prometheus metrics collected in the entire service - pub fn prometheus_http_service() -> Self { - Service::new( - "Prometheus metric HTTP".to_string(), - PrometheusServer::new(), - ) - } -} diff --git a/pingora-core/src/tls/mod.rs b/pingora-core/src/tls/mod.rs deleted file mode 100644 index cc5a8dfd..00000000 --- a/pingora-core/src/tls/mod.rs +++ /dev/null @@ -1,824 +0,0 @@ -// Copyright 2024 Cloudflare, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! This module contains a dummy TLS implementation for the scenarios where real TLS -//! 
implementations are unavailable. - -macro_rules! impl_display { - ($ty:ty) => { - impl std::fmt::Display for $ty { - fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - Ok(()) - } - } - }; -} - -macro_rules! impl_deref { - ($from:ty => $to:ty) => { - impl std::ops::Deref for $from { - type Target = $to; - fn deref(&self) -> &$to { - panic!("Not implemented"); - } - } - impl std::ops::DerefMut for $from { - fn deref_mut(&mut self) -> &mut $to { - panic!("Not implemented"); - } - } - }; -} - -pub mod ssl { - use super::error::ErrorStack; - use super::x509::verify::X509VerifyParamRef; - use super::x509::{X509VerifyResult, X509}; - - /// An error returned from an ALPN selection callback. - pub struct AlpnError; - impl AlpnError { - /// Terminate the handshake with a fatal alert. - pub const ALERT_FATAL: AlpnError = Self {}; - - /// Do not select a protocol, but continue the handshake. - pub const NOACK: AlpnError = Self {}; - } - - /// A type which allows for configuration of a client-side TLS session before connection. - pub struct ConnectConfiguration; - impl_deref! {ConnectConfiguration => SslRef} - impl ConnectConfiguration { - /// Configures the use of Server Name Indication (SNI) when connecting. - pub fn set_use_server_name_indication(&mut self, _use_sni: bool) { - panic!("Not implemented"); - } - - /// Configures the use of hostname verification when connecting. - pub fn set_verify_hostname(&mut self, _verify_hostname: bool) { - panic!("Not implemented"); - } - - /// Returns an `Ssl` configured to connect to the provided domain. - pub fn into_ssl(self, _domain: &str) -> Result { - panic!("Not implemented"); - } - - /// Like `SslContextBuilder::set_verify`. - pub fn set_verify(&mut self, _mode: SslVerifyMode) { - panic!("Not implemented"); - } - - /// Like `SslContextBuilder::set_alpn_protos`. 
- pub fn set_alpn_protos(&mut self, _protocols: &[u8]) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Returns a mutable reference to the X509 verification configuration. - pub fn param_mut(&mut self) -> &mut X509VerifyParamRef { - panic!("Not implemented"); - } - } - - /// An SSL error. - #[derive(Debug)] - pub struct Error; - impl_display!(Error); - impl Error { - pub fn code(&self) -> ErrorCode { - panic!("Not implemented"); - } - } - - /// An error code returned from SSL functions. - #[derive(PartialEq)] - pub struct ErrorCode(i32); - impl ErrorCode { - /// An error occurred in the SSL library. - pub const SSL: ErrorCode = Self(0); - } - - /// An identifier of a session name type. - pub struct NameType; - impl NameType { - pub const HOST_NAME: NameType = Self {}; - } - - /// The state of an SSL/TLS session. - pub struct Ssl; - impl Ssl { - /// Creates a new `Ssl`. - pub fn new(_ctx: &SslContextRef) -> Result { - panic!("Not implemented"); - } - } - impl_deref! {Ssl => SslRef} - - /// A type which wraps server-side streams in a TLS session. - pub struct SslAcceptor; - impl SslAcceptor { - /// Creates a new builder configured to connect to non-legacy clients. This should - /// generally be considered a reasonable default choice. - pub fn mozilla_intermediate_v5( - _method: SslMethod, - ) -> Result { - panic!("Not implemented"); - } - } - - /// A builder for `SslAcceptor`s. - pub struct SslAcceptorBuilder; - impl SslAcceptorBuilder { - /// Consumes the builder, returning a `SslAcceptor`. - pub fn build(self) -> SslAcceptor { - panic!("Not implemented"); - } - - /// Sets the callback used by a server to select a protocol for Application Layer Protocol - /// Negotiation (ALPN). - pub fn set_alpn_select_callback(&mut self, _callback: F) - where - F: for<'a> Fn(&mut SslRef, &'a [u8]) -> Result<&'a [u8], AlpnError> - + 'static - + Sync - + Send, - { - panic!("Not implemented"); - } - - /// Loads a certificate chain from a file. 
- pub fn set_certificate_chain_file>( - &mut self, - _file: P, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Loads the private key from a file. - pub fn set_private_key_file>( - &mut self, - _file: P, - _file_type: SslFiletype, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the maximum supported protocol version. - pub fn set_max_proto_version( - &mut self, - _version: Option, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the list of supported ciphers for protocols before TLSv1.3. - pub fn set_cipher_list(&mut self, _cipher_list: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the list of supported cipher suites for TLSv1.3. - pub fn set_ciphersuites(&mut self, _ciphersuites: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the minimum supported protocol version. - pub fn set_min_proto_version( - &mut self, - _version: Option, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - } - - /// Reference to an [`SslCipher`]. - pub struct SslCipherRef; - impl SslCipherRef { - /// Returns the name of the cipher. - pub fn name(&self) -> &'static str { - panic!("Not implemented"); - } - } - - /// A type which wraps client-side streams in a TLS session. - pub struct SslConnector; - impl SslConnector { - /// Creates a new builder for TLS connections. - pub fn builder(_method: SslMethod) -> Result { - panic!("Not implemented"); - } - - /// Returns a structure allowing for configuration of a single TLS session before connection. - pub fn configure(&self) -> Result { - panic!("Not implemented"); - } - - /// Returns a shared reference to the inner raw `SslContext`. - pub fn context(&self) -> &SslContextRef { - panic!("Not implemented"); - } - } - - /// A builder for `SslConnector`s. - pub struct SslConnectorBuilder; - impl SslConnectorBuilder { - /// Consumes the builder, returning an `SslConnector`. 
- pub fn build(self) -> SslConnector { - panic!("Not implemented"); - } - - /// Sets the list of supported ciphers for protocols before TLSv1.3. - pub fn set_cipher_list(&mut self, _cipher_list: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the context’s supported signature algorithms. - pub fn set_sigalgs_list(&mut self, _sigalgs: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the minimum supported protocol version. - pub fn set_min_proto_version( - &mut self, - _version: Option, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the maximum supported protocol version. - pub fn set_max_proto_version( - &mut self, - _version: Option, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Use the default locations of trusted certificates for verification. - pub fn set_default_verify_paths(&mut self) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Loads trusted root certificates from a file. - pub fn set_ca_file>( - &mut self, - _file: P, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Loads a leaf certificate from a file. - pub fn set_certificate_file>( - &mut self, - _file: P, - _file_type: SslFiletype, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Loads the private key from a file. - pub fn set_private_key_file>( - &mut self, - _file: P, - _file_type: SslFiletype, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets the TLS key logging callback. - pub fn set_keylog_callback(&mut self, _callback: F) - where - F: Fn(&SslRef, &str) + 'static + Sync + Send, - { - panic!("Not implemented"); - } - } - - /// A context object for TLS streams. - pub struct SslContext; - impl SslContext { - /// Creates a new builder object for an `SslContext`. - pub fn builder(_method: SslMethod) -> Result { - panic!("Not implemented"); - } - } - impl_deref! 
{SslContext => SslContextRef} - - /// A builder for `SslContext`s. - pub struct SslContextBuilder; - impl SslContextBuilder { - /// Consumes the builder, returning a new `SslContext`. - pub fn build(self) -> SslContext { - panic!("Not implemented"); - } - } - - /// Reference to [`SslContext`] - pub struct SslContextRef; - - /// An identifier of the format of a certificate or key file. - pub struct SslFiletype; - impl SslFiletype { - /// The PEM format. - pub const PEM: SslFiletype = Self {}; - } - - /// A type specifying the kind of protocol an `SslContext`` will speak. - pub struct SslMethod; - impl SslMethod { - /// Support all versions of the TLS protocol. - pub fn tls() -> SslMethod { - panic!("Not implemented"); - } - } - - /// Reference to an [`Ssl`]. - pub struct SslRef; - impl SslRef { - /// Like [`SslContextBuilder::set_verify`]. - pub fn set_verify(&mut self, _mode: SslVerifyMode) { - panic!("Not implemented"); - } - - /// Returns the current cipher if the session is active. - pub fn current_cipher(&self) -> Option<&SslCipherRef> { - panic!("Not implemented"); - } - - /// Sets the host name to be sent to the server for Server Name Indication (SNI). - pub fn set_hostname(&mut self, _hostname: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Returns the peer’s certificate, if present. - pub fn peer_certificate(&self) -> Option { - panic!("Not implemented"); - } - - /// Returns the certificate verification result. - pub fn verify_result(&self) -> X509VerifyResult { - panic!("Not implemented"); - } - - /// Returns a string describing the protocol version of the session. - pub fn version_str(&self) -> &'static str { - panic!("Not implemented"); - } - - /// Returns the protocol selected via Application Layer Protocol Negotiation (ALPN). - pub fn selected_alpn_protocol(&self) -> Option<&[u8]> { - panic!("Not implemented"); - } - - /// Returns the servername sent by the client via Server Name Indication (SNI). 
- pub fn servername(&self, _type_: NameType) -> Option<&str> { - panic!("Not implemented"); - } - } - - /// Options controlling the behavior of certificate verification. - pub struct SslVerifyMode; - impl SslVerifyMode { - /// Verifies that the peer’s certificate is trusted. - pub const PEER: Self = Self {}; - - /// Disables verification of the peer’s certificate. - pub const NONE: Self = Self {}; - } - - /// An SSL/TLS protocol version. - pub struct SslVersion; - impl SslVersion { - /// TLSv1.0 - pub const TLS1: SslVersion = Self {}; - - /// TLSv1.2 - pub const TLS1_2: SslVersion = Self {}; - - /// TLSv1.3 - pub const TLS1_3: SslVersion = Self {}; - } - - /// A standard implementation of protocol selection for Application Layer Protocol Negotiation - /// (ALPN). - pub fn select_next_proto<'a>(_server: &[u8], _client: &'a [u8]) -> Option<&'a [u8]> { - panic!("Not implemented"); - } -} - -pub mod ssl_sys { - pub const X509_V_OK: i32 = 0; - pub const X509_V_ERR_INVALID_CALL: i32 = 69; -} - -pub mod error { - use super::ssl::Error; - - /// Collection of [`Errors`] from OpenSSL. - #[derive(Debug)] - pub struct ErrorStack; - impl_display!(ErrorStack); - impl std::error::Error for ErrorStack {} - impl ErrorStack { - /// Returns the contents of the OpenSSL error stack. - pub fn get() -> ErrorStack { - panic!("Not implemented"); - } - - /// Returns the errors in the stack. - pub fn errors(&self) -> &[Error] { - panic!("Not implemented"); - } - } -} - -pub mod x509 { - use super::asn1::{Asn1IntegerRef, Asn1StringRef, Asn1TimeRef}; - use super::error::ErrorStack; - use super::hash::{DigestBytes, MessageDigest}; - use super::nid::Nid; - - /// An `X509` public key certificate. - #[derive(Debug, Clone)] - pub struct X509; - impl_deref! {X509 => X509Ref} - impl X509 { - /// Deserializes a PEM-encoded X509 structure. - pub fn from_pem(_pem: &[u8]) -> Result { - panic!("Not implemented"); - } - } - - /// A type to destructure and examine an `X509Name`. 
- pub struct X509NameEntries<'a> { - marker: std::marker::PhantomData<&'a ()>, - } - impl<'a> Iterator for X509NameEntries<'a> { - type Item = &'a X509NameEntryRef; - fn next(&mut self) -> Option<&'a X509NameEntryRef> { - panic!("Not implemented"); - } - } - - /// Reference to `X509NameEntry`. - pub struct X509NameEntryRef; - impl X509NameEntryRef { - pub fn data(&self) -> &Asn1StringRef { - panic!("Not implemented"); - } - } - - /// Reference to `X509Name`. - pub struct X509NameRef; - impl X509NameRef { - /// Returns the name entries by the nid. - pub fn entries_by_nid(&self, _nid: Nid) -> X509NameEntries<'_> { - panic!("Not implemented"); - } - } - - /// Reference to `X509`. - pub struct X509Ref; - impl X509Ref { - /// Returns this certificate’s subject name. - pub fn subject_name(&self) -> &X509NameRef { - panic!("Not implemented"); - } - - /// Returns a digest of the DER representation of the certificate. - pub fn digest(&self, _hash_type: MessageDigest) -> Result { - panic!("Not implemented"); - } - - /// Returns the certificate’s Not After validity period. - pub fn not_after(&self) -> &Asn1TimeRef { - panic!("Not implemented"); - } - - /// Returns this certificate’s serial number. - pub fn serial_number(&self) -> &Asn1IntegerRef { - panic!("Not implemented"); - } - } - - /// The result of peer certificate verification. - pub struct X509VerifyResult; - impl X509VerifyResult { - /// Return the integer representation of an `X509VerifyResult`. - pub fn as_raw(&self) -> i32 { - panic!("Not implemented"); - } - } - - pub mod store { - use super::super::error::ErrorStack; - use super::X509; - - /// A builder type used to construct an `X509Store`. - pub struct X509StoreBuilder; - impl X509StoreBuilder { - /// Returns a builder for a certificate store.. - pub fn new() -> Result { - panic!("Not implemented"); - } - - /// Constructs the `X509Store`. - pub fn build(self) -> X509Store { - panic!("Not implemented"); - } - - /// Adds a certificate to the certificate store. 
- pub fn add_cert(&mut self, _cert: X509) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - } - - /// A certificate store to hold trusted X509 certificates. - pub struct X509Store; - impl_deref! {X509Store => X509StoreRef} - - /// Reference to an `X509Store`. - pub struct X509StoreRef; - } - - pub mod verify { - /// Reference to `X509VerifyParam`. - pub struct X509VerifyParamRef; - } -} - -pub mod nid { - /// A numerical identifier for an OpenSSL object. - pub struct Nid; - impl Nid { - pub const COMMONNAME: Nid = Self {}; - pub const ORGANIZATIONNAME: Nid = Self {}; - pub const ORGANIZATIONALUNITNAME: Nid = Self {}; - } -} - -pub mod pkey { - use super::error::ErrorStack; - - /// A public or private key. - #[derive(Clone)] - pub struct PKey { - marker: std::marker::PhantomData, - } - impl std::ops::Deref for PKey { - type Target = PKeyRef; - fn deref(&self) -> &PKeyRef { - panic!("Not implemented"); - } - } - impl std::ops::DerefMut for PKey { - fn deref_mut(&mut self) -> &mut PKeyRef { - panic!("Not implemented"); - } - } - impl PKey { - pub fn private_key_from_pem(_pem: &[u8]) -> Result, ErrorStack> { - panic!("Not implemented"); - } - } - - /// Reference to `PKey`. - pub struct PKeyRef { - marker: std::marker::PhantomData, - } - - /// A tag type indicating that a key has private components. - #[derive(Clone)] - pub enum Private {} - unsafe impl HasPrivate for Private {} - - /// A trait indicating that a key has private components. - pub unsafe trait HasPrivate {} -} - -pub mod hash { - /// A message digest algorithm. - pub struct MessageDigest; - impl MessageDigest { - pub fn sha256() -> MessageDigest { - panic!("Not implemented"); - } - } - - /// The resulting bytes of a digest. - pub struct DigestBytes; - impl AsRef<[u8]> for DigestBytes { - fn as_ref(&self) -> &[u8] { - panic!("Not implemented"); - } - } -} - -pub mod asn1 { - use super::bn::BigNum; - use super::error::ErrorStack; - - /// A reference to an `Asn1Integer`. 
- pub struct Asn1IntegerRef; - impl Asn1IntegerRef { - /// Converts the integer to a `BigNum`. - pub fn to_bn(&self) -> Result { - panic!("Not implemented"); - } - } - - /// A reference to an `Asn1String`. - pub struct Asn1StringRef; - impl Asn1StringRef { - pub fn as_utf8(&self) -> Result<&str, ErrorStack> { - panic!("Not implemented"); - } - } - - /// Reference to an `Asn1Time` - pub struct Asn1TimeRef; - impl_display! {Asn1TimeRef} -} - -pub mod bn { - use super::error::ErrorStack; - - /// Dynamically sized large number implementation - pub struct BigNum; - impl BigNum { - /// Returns a hexadecimal string representation of `self`. - pub fn to_hex_str(&self) -> Result<&str, ErrorStack> { - panic!("Not implemented"); - } - } -} - -pub mod ext { - use super::error::ErrorStack; - use super::pkey::{HasPrivate, PKeyRef}; - use super::ssl::{Ssl, SslAcceptor, SslRef}; - use super::x509::store::X509StoreRef; - use super::x509::verify::X509VerifyParamRef; - use super::x509::X509Ref; - - /// Add name as an additional reference identifier that can match the peer's certificate - pub fn add_host(_verify_param: &mut X509VerifyParamRef, _host: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Set the verify cert store of `_ssl` - pub fn ssl_set_verify_cert_store( - _ssl: &mut SslRef, - _cert_store: &X509StoreRef, - ) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Load the certificate into `_ssl` - pub fn ssl_use_certificate(_ssl: &mut SslRef, _cert: &X509Ref) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Load the private key into `_ssl` - pub fn ssl_use_private_key(_ssl: &mut SslRef, _key: &PKeyRef) -> Result<(), ErrorStack> - where - T: HasPrivate, - { - panic!("Not implemented"); - } - - /// Clear the error stack - pub fn clear_error_stack() {} - - /// Create a new [Ssl] from &[SslAcceptor] - pub fn ssl_from_acceptor(_acceptor: &SslAcceptor) -> Result { - panic!("Not implemented"); - } - - /// Suspend the 
TLS handshake when a certificate is needed. - pub fn suspend_when_need_ssl_cert(_ssl: &mut SslRef) { - panic!("Not implemented"); - } - - /// Unblock a TLS handshake after the certificate is set. - pub fn unblock_ssl_cert(_ssl: &mut SslRef) { - panic!("Not implemented"); - } - - /// Whether the TLS error is SSL_ERROR_WANT_X509_LOOKUP - pub fn is_suspended_for_cert(_error: &super::ssl::Error) -> bool { - panic!("Not implemented"); - } - - /// Add the certificate into the cert chain of `_ssl` - pub fn ssl_add_chain_cert(_ssl: &mut SslRef, _cert: &X509Ref) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Set renegotiation - pub fn ssl_set_renegotiate_mode_freely(_ssl: &mut SslRef) {} - - /// Set the curves/groups of `_ssl` - pub fn ssl_set_groups_list(_ssl: &mut SslRef, _groups: &str) -> Result<(), ErrorStack> { - panic!("Not implemented"); - } - - /// Sets whether a second keyshare to be sent in client hello when PQ is used. - pub fn ssl_use_second_key_share(_ssl: &mut SslRef, _enabled: bool) {} - - /// Get a mutable SslRef ouf of SslRef, which is a missing functionality even when holding &mut SslStream - /// # Safety - pub unsafe fn ssl_mut(_ssl: &SslRef) -> &mut SslRef { - panic!("Not implemented"); - } -} - -pub mod tokio_ssl { - use std::pin::Pin; - use std::task::{Context, Poll}; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - - use super::error::ErrorStack; - use super::ssl::{Error, Ssl, SslRef}; - - /// A TLS session over a stream. - #[derive(Debug)] - pub struct SslStream { - marker: std::marker::PhantomData, - } - impl SslStream { - /// Creates a new `SslStream`. - pub fn new(_ssl: Ssl, _stream: S) -> Result { - panic!("Not implemented"); - } - - /// Initiates a client-side TLS handshake. - pub async fn connect(self: Pin<&mut Self>) -> Result<(), Error> { - panic!("Not implemented"); - } - - /// Initiates a server-side TLS handshake. 
- pub async fn accept(self: Pin<&mut Self>) -> Result<(), Error> { - panic!("Not implemented"); - } - - /// Returns a shared reference to the `Ssl` object associated with this stream. - pub fn ssl(&self) -> &SslRef { - panic!("Not implemented"); - } - - /// Returns a shared reference to the underlying stream. - pub fn get_ref(&self) -> &S { - panic!("Not implemented"); - } - - /// Returns a mutable reference to the underlying stream. - pub fn get_mut(&mut self) -> &mut S { - panic!("Not implemented"); - } - } - impl AsyncRead for SslStream - where - S: AsyncRead + AsyncWrite, - { - fn poll_read( - self: Pin<&mut Self>, - _ctx: &mut Context<'_>, - _buf: &mut ReadBuf<'_>, - ) -> Poll> { - panic!("Not implemented"); - } - } - impl AsyncWrite for SslStream - where - S: AsyncRead + AsyncWrite, - { - fn poll_write( - self: Pin<&mut Self>, - _ctx: &mut Context<'_>, - _buf: &[u8], - ) -> Poll> { - panic!("Not implemented"); - } - - fn poll_flush(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { - panic!("Not implemented"); - } - - fn poll_shutdown( - self: Pin<&mut Self>, - _ctx: &mut Context<'_>, - ) -> Poll> { - panic!("Not implemented"); - } - } -} diff --git a/pingora-core/src/upstreams/peer.rs b/pingora-core/src/upstreams/peer.rs index c9ae0a66..c7f5e40c 100644 --- a/pingora-core/src/upstreams/peer.rs +++ b/pingora-core/src/upstreams/peer.rs @@ -33,6 +33,7 @@ use pingora_error::{ }; #[cfg(feature = "s2n")] use pingora_s2n::S2NPolicy; +use std::borrow::Cow; use std::collections::BTreeMap; use std::fmt::{Display, Formatter, Result as FmtResult}; use std::hash::{Hash, Hasher}; @@ -431,8 +432,14 @@ pub struct PeerOptions { pub s2n_security_policy: Option, #[cfg(feature = "s2n")] pub max_blinding_delay: Option, - // how many concurrent h2 stream are allowed in the same connection + /// How many concurrent h2 streams are allowed in the same connection. pub max_h2_streams: usize, + /// Initial per-stream H2 receive window size in bytes. 
+ /// If `None`, the default of 8MB is used. + pub h2_stream_window_size: Option, + /// Initial connection-level H2 receive window size in bytes. + /// If `None`, the default of 8MB is used. + pub h2_connection_window_size: Option, /// Allow invalid Content-Length in HTTP/1 responses (non-RFC compliant). /// /// When enabled, invalid Content-Length responses are treated as close-delimited responses. @@ -441,16 +448,16 @@ pub struct PeerOptions { /// It exists primarily for compatibility with legacy servers that send malformed headers. pub allow_h1_response_invalid_content_length: bool, pub extra_proxy_headers: BTreeMap>, - // The list of curve the tls connection should advertise - // if `None`, the default curves will be used - pub curves: Option<&'static str>, - // see ssl_use_second_key_share + /// The list of curves the tls connection should advertise + /// if `None`, the default curves will be used + pub curves: Option>, + /// see ssl_use_second_key_share pub second_keyshare: bool, - // whether to enable TCP fast open + /// whether to enable TCP fast open pub tcp_fast_open: bool, - // use Arc because Clone is required but not allowed in trait object + /// use Arc because Clone is required but not allowed in trait object pub tracer: Option, - // A custom L4 connector to use to establish new L4 connections + /// A custom L4 connector to use to establish new L4 connections pub custom_l4: Option>, #[derivative(Debug = "ignore")] pub upstream_tcp_sock_tweak_hook: @@ -494,6 +501,8 @@ impl PeerOptions { #[cfg(feature = "s2n")] max_blinding_delay: None, max_h2_streams: 1, + h2_stream_window_size: None, + h2_connection_window_size: None, allow_h1_response_invalid_content_length: false, extra_proxy_headers: BTreeMap::new(), curves: None, @@ -685,6 +694,12 @@ impl Hash for HttpPeer { self.group_key.hash(state); // max h2 stream settings self.options.max_h2_streams.hash(state); + // h2_stream_window_size and h2_connection_window_size are intentionally excluded + // from 
the reuse hash for now. These are per-connection settings applied at handshake + // time and may be revisited alongside other h2 settings that could be dynamically + // adjusted over the lifetime of a connection. + self.options.curves.hash(state); + self.options.second_keyshare.hash(state); } } diff --git a/pingora-core/src/utils/tls/rustls.rs b/pingora-core/src/utils/tls/rustls.rs index 429b3724..d4e4e5c9 100644 --- a/pingora-core/src/utils/tls/rustls.rs +++ b/pingora-core/src/utils/tls/rustls.rs @@ -101,17 +101,17 @@ pub struct CertKey { certificates: Vec, } -#[self_referencing] +#[self_referencing(pub_extras)] #[derive(Debug)] pub struct WrappedX509 { - raw_cert: Vec, + pub raw_cert: Vec, #[borrows(raw_cert)] #[covariant] - cert: X509Certificate<'this>, + pub cert: X509Certificate<'this>, } -fn parse_x509(raw_cert: &C) -> X509Certificate<'_> +pub fn parse_x509(raw_cert: &C) -> X509Certificate<'_> where C: AsRef<[u8]>, { diff --git a/pingora-core/tests/bootstrap_as_a_service.rs b/pingora-core/tests/bootstrap_as_a_service.rs new file mode 100644 index 00000000..fa88ff20 --- /dev/null +++ b/pingora-core/tests/bootstrap_as_a_service.rs @@ -0,0 +1,136 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Integration tests for `bootstrap_as_a_service`. +//! +//! Verifies that when `bootstrap_as_a_service()` dependencies are declared, the +//! 
`BootstrapComplete` execution phase is not reached until all dependency services have +//! finished their initialization work. + +use async_trait::async_trait; +use pingora_core::server::ShutdownWatch; +use pingora_core::server::{configuration::ServerConf, ExecutionPhase, RunArgs, Server}; +use pingora_core::services::background::{background_service, BackgroundService}; +use pingora_core::services::ServiceReadyNotifier; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +/// A background service that sets a flag when it completes, after an optional delay. +/// +/// Signals readiness only after completing its initialization work so that dependent +/// services (like `BootstrapService`) cannot start until this service is truly done. +struct TrackableService { + delay: Duration, + completed: Arc, +} + +#[async_trait] +impl BackgroundService for TrackableService { + async fn start_with_ready_notifier( + &self, + _shutdown: ShutdownWatch, + ready_notifier: ServiceReadyNotifier, + ) { + if !self.delay.is_zero() { + tokio::time::sleep(self.delay).await; + } + self.completed.store(true, Ordering::SeqCst); + // Signal readiness only after work is done — this is what the dependency + // mechanism waits on before allowing BootstrapService to proceed. + ready_notifier.notify_ready(); + } +} + +/// Verifies that `bootstrap_as_a_service` does not reach `BootstrapComplete` until all +/// declared dependency services have finished their initialization work. +#[test] +fn test_bootstrap_waits_for_dependencies() { + let conf = ServerConf { + grace_period_seconds: Some(1), + graceful_shutdown_timeout_seconds: Some(1), + ..Default::default() + }; + + let mut server = Server::new_with_opt_and_conf(None, conf); + let mut phase = server.watch_execution_phase(); + + // Two dependency services with delays. The second (150 ms) sets the pace. 
+ let dep1_done = Arc::new(AtomicBool::new(false)); + let dep2_done = Arc::new(AtomicBool::new(false)); + + let dep1_handle = server.add_service(background_service( + "dep1", + TrackableService { + delay: Duration::from_millis(50), + completed: dep1_done.clone(), + }, + )); + let dep2_handle = server.add_service(background_service( + "dep2", + TrackableService { + delay: Duration::from_millis(150), + completed: dep2_done.clone(), + }, + )); + + // BootstrapService must not reach BootstrapComplete until dep1 and dep2 are done. + let bootstrap_handle = server.bootstrap_as_a_service(); + bootstrap_handle.add_dependencies([&dep1_handle, &dep2_handle]); + + // When using bootstrap_as_a_service, do NOT call server.bootstrap() separately — + // the BootstrapService runs as a background service during run(), and emits + // BootstrapComplete only after all its declared dependencies are ready. + let _join = std::thread::spawn(move || { + server.run(RunArgs::default()); + }); + + let mut received_bootstrap = false; + let mut received_bootstrap_complete = false; + + // Collect phases until BootstrapComplete is seen. Running may arrive + // before or after Bootstrap/BootstrapComplete since main_loop starts + // concurrently with the service runtimes. + loop { + match phase.blocking_recv() { + Ok(ExecutionPhase::Bootstrap) => { + received_bootstrap = true; + } + Ok(ExecutionPhase::BootstrapComplete) => { + // Both dependencies must have set their flags before bootstrap completes. + assert!( + dep1_done.load(Ordering::SeqCst), + "dep1 should be done before BootstrapComplete" + ); + assert!( + dep2_done.load(Ordering::SeqCst), + "dep2 should be done before BootstrapComplete" + ); + received_bootstrap_complete = true; + break; + } + Ok(_) => {} + Err(_) => break, + } + } + + assert!(received_bootstrap, "should have seen Bootstrap phase"); + assert!( + received_bootstrap_complete, + "should have seen BootstrapComplete phase" + ); + + // Shut down cleanly. 
+ std::process::exit(0); +} diff --git a/pingora-core/tests/test_basic.rs b/pingora-core/tests/test_basic.rs index 445d75b9..ae6ee810 100644 --- a/pingora-core/tests/test_basic.rs +++ b/pingora-core/tests/test_basic.rs @@ -14,6 +14,8 @@ mod utils; +#[cfg(all(unix, feature = "any_tls"))] +use hyper_util::client::legacy::Client; #[cfg(all(unix, feature = "any_tls"))] use hyperlocal::{UnixClientExt, Uri}; @@ -55,7 +57,8 @@ async fn test_https_http2() { async fn test_uds() { utils::init(); let url = Uri::new("/tmp/echo.sock", "/").into(); - let client = hyper::Client::unix(); + let client: Client> = + Client::unix(); let res = client.get(url).await.unwrap(); assert_eq!(res.status(), reqwest::StatusCode::OK); diff --git a/pingora-http/src/lib.rs b/pingora-http/src/lib.rs index 954be81b..f66b4be8 100644 --- a/pingora-http/src/lib.rs +++ b/pingora-http/src/lib.rs @@ -256,22 +256,55 @@ impl RequestHeader { /// Generally prefer [Self::set_uri()] to modify the header's URI if able. /// /// This API is to allow supporting non UTF-8 cases. + /// + /// Accepts every URI form HTTP/1.1 request-lines can carry + /// (RFC 9112 § 3.2): origin-form (`/path?query`), absolute-form + /// (`http://host/path`), authority-form (`host:port`, used by + /// CONNECT per RFC 9110 § 9.3.6), and asterisk-form (`*`, used by + /// OPTIONS). + /// + /// Resolution order for the structured `base.uri`: + /// 1. `Uri::builder().path_and_query(...)` — origin-form, the hot + /// path for ~all traffic. + /// 2. `Uri::try_from(...)` — authority / absolute / asterisk forms + /// (CONNECT request-lines land here). + /// 3. Neither parses (e.g. a bare `\` path: valid UTF-8 but not a + /// structurally valid URI — `http >= 1.4` rejects paths that + /// don't start with `/`, where `http <= 1.3` accepted them). 
+ /// Rather than reject the request, preserve the raw bytes in + /// `raw_path_fallback` (the same mechanism the non-UTF8 branch + /// already uses) and set a `/` sentinel on `base.uri` so the + /// structured URI stays valid. `raw_path()` and the H1 wire + /// serializer read `raw_path_fallback` first, so the original + /// bytes still round-trip verbatim. pub fn set_raw_path(&mut self, path: &[u8]) -> Result<()> { - if let Ok(p) = std::str::from_utf8(path) { - let uri = Uri::builder() + fn parse(p: &str) -> Option { + // origin-form first, then authority/absolute/asterisk forms. + Uri::builder() .path_and_query(p) .build() - .explain_err(InvalidHTTPHeader, |_| format!("invalid uri {}", p))?; - self.base.uri = uri; - // keep raw_path empty, no need to store twice + .ok() + .or_else(|| Uri::try_from(p).ok()) + } + // Sentinel for paths that are valid UTF-8 (or lossy-decodable) + // but not structurally valid URIs; the real bytes live in + // raw_path_fallback. + const SENTINEL: &str = "/"; + if let Ok(p) = std::str::from_utf8(path) { + match parse(p) { + Some(uri) => { + self.base.uri = uri; + // keep raw_path empty, no need to store twice + } + None => { + self.base.uri = Uri::from_static(SENTINEL); + self.raw_path_fallback = path.to_vec(); + } + } } else { - // put a valid utf-8 path into base for read only access + // Non-UTF8: keep a readable base URI, preserve raw bytes. let lossy_str = String::from_utf8_lossy(path); - let uri = Uri::builder() - .path_and_query(lossy_str.as_ref()) - .build() - .explain_err(InvalidHTTPHeader, |_| format!("invalid uri {}", lossy_str))?; - self.base.uri = uri; + self.base.uri = parse(lossy_str.as_ref()).unwrap_or_else(|| Uri::from_static(SENTINEL)); self.raw_path_fallback = path.to_vec(); } Ok(()) @@ -294,18 +327,21 @@ impl RequestHeader { /// Return the request path in its raw format /// /// Non-UTF8 is supported. 
+ /// + /// For authority-form request URIs (RFC 9110 § 9.3.6 / CONNECT) + /// `Uri::path_and_query()` returns `None`; we fall back to the + /// raw authority bytes (e.g. `pingora.org:443`). For asterisk-form + /// (OPTIONS *) both `path_and_query` and `authority` are absent + /// — return an empty slice rather than panic. pub fn raw_path(&self) -> &[u8] { if !self.raw_path_fallback.is_empty() { &self.raw_path_fallback + } else if let Some(pq) = self.base.uri.path_and_query() { + pq.as_str().as_bytes() + } else if let Some(auth) = self.base.uri.authority() { + auth.as_str().as_bytes() } else { - // Url should always be set - self.base - .uri - .path_and_query() - .as_ref() - .unwrap() - .as_str() - .as_bytes() + &[] } } @@ -816,9 +852,10 @@ mod tests { let mut buf: Vec = vec![]; req.header_to_h1_wire(&mut buf); assert_eq!(buf, b"foo: Bar\r\n"); - req.case_header_iter().for_each(|(_, _)| { - unreachable!("request has no case"); - }); + assert!( + req.case_header_iter().next().is_none(), + "request has no case" + ); let mut resp = ResponseHeader::new_no_case(None); resp.insert_header("foo", "bar").unwrap(); @@ -826,9 +863,10 @@ mod tests { let mut buf: Vec = vec![]; resp.header_to_h1_wire(&mut buf); assert_eq!(buf, b"foo: Bar\r\n"); - resp.case_header_iter().for_each(|(_, _)| { - unreachable!("response has no case"); - }); + assert!( + resp.case_header_iter().next().is_none(), + "response has no case" + ); } #[test] diff --git a/pingora-lru/benches/bench_lru.rs b/pingora-lru/benches/bench_lru.rs index c0bdc776..02dadec8 100644 --- a/pingora-lru/benches/bench_lru.rs +++ b/pingora-lru/benches/bench_lru.rs @@ -12,137 +12,237 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Benchmark for `Lru::promote()` vs `Lru::promote_top_n()`. +//! +//! Tests both small (original) and production-scale LRU sizes to show how +//! the `promote_top_n` optimization behaves at different scales. +//! +//! 
Run with: `cargo bench -p pingora-lru --bench bench_lru` +//! +//! ## Results (Apple M3 Max, 2026-04-03) +//! +//! Benchmark tiers simulate production-level data sizes (items/shard +//! ranging from ~100 to ~500K) with both uniform-hot and heavy-hitter +//! access patterns. +//! +//! ### 8-threaded — uniform hot set (10% of items are 100x hotter) +//! +//! | Items/shard | promote | top_n(0) | top_n(3) | top_n(10) | top_n(50) | top_n(100) | +//! |-------------|------------|----------|----------|-----------|-----------|------------| +//! | 10 (orig) | 366ns | 476ns | 271ns | **164ns** | 164ns | 164ns | +//! | 100K (typ) | **457ns** | 480ns | 437ns | 520ns | 1227ns | 2394ns | +//! +//! ### 8-threaded — heavy hitters (10 or 100 items are 10,000x hotter) +//! +//! | Items/shard | promote | top_n(0) | top_n(3) | top_n(10) | top_n(50) | top_n(100) | +//! |-----------------|------------|----------|----------|-----------|-----------|------------| +//! | 100K, 10 hot | **649ns** | 688ns | 652ns | 773ns | 1811ns | 3534ns | +//! | 100K, 100 hot | **607ns** | 632ns | 607ns | 716ns | 1493ns | 2759ns | +//! +//! ### Single-threaded — uniform hot set (10% of items are 100x hotter) +//! +//! | Items/shard | promote | top_n(0) | top_n(3) | top_n(10) | top_n(50) | top_n(100) | +//! |-------------|------------|----------|----------|-----------|-----------|------------| +//! | 10 (orig) | 22ns | 20ns | 29ns | **30ns** | 30ns | 30ns | +//! | 100K (typ) | **297ns** | 306ns | 314ns | 332ns | 663ns | 1092ns | +//! +//! **Conclusions**: +//! +//! - `promote_top_n(0)` is strictly worse than `promote()` — it takes a +//! wasted read lock before falling through to the write lock every time. +//! +//! - `promote_top_n(n)` for n > 0 only wins at the original small scale +//! (10 items/shard) where the threshold covers the entire shard. +//! +//! - Even with heavy-hitter patterns (10 items at 10,000x weight), +//! `promote()` ties or wins at production scale. With 10 hot items +//! 
across 32 shards, most shards have 0-1 hot items, so the read-lock +//! scan is wasted on the majority of cold-item accesses. +//! +//! - At production scale (~100K+ items/shard), plain `promote()` is fastest +//! regardless of access pattern. + use rand::distributions::WeightedIndex; use rand::prelude::*; -use std::sync::Arc; +use std::sync::{Arc, Barrier}; use std::thread; use std::time::Instant; -// Non-uniform distributions, 100 items, 10 of them are 100x more likely to appear -const WEIGHTS: &[usize] = &[ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, - 100, 100, 100, 100, 100, 100, 100, -]; - const ITERATIONS: usize = 5_000_000; const THREADS: usize = 8; -fn main() { - let lru = parking_lot::Mutex::new(lru::LruCache::::unbounded()); - - let plru = pingora_lru::Lru::<(), 10>::with_capacity(1000, 100); - // populate first, then we bench access/promotion - for i in 0..WEIGHTS.len() { - lru.lock().put(i as u64, ()); - } - for i in 0..WEIGHTS.len() { - plru.admit(i as u64, (), 1); +/// Build a weight distribution where the first `hot_count` items have +/// `hot_weight`x the access probability. 
+fn make_weights(n: usize, hot_count: usize, hot_weight: usize) -> Vec { + let mut weights = vec![1usize; n]; + for w in weights.iter_mut().take(hot_count) { + *w = hot_weight; } + weights +} - // single thread - let mut rng = thread_rng(); - let dist = WeightedIndex::new(WEIGHTS).unwrap(); +fn bench_config(label: &str, items: usize, shards: usize, hot_pct: usize, hot_weight: usize) { + let hot_count = items * hot_pct / 100; + bench_config_abs(label, items, shards, hot_count, hot_weight); +} - let before = Instant::now(); - for _ in 0..ITERATIONS { - lru.lock().get(&(dist.sample(&mut rng) as u64)); - } - let elapsed = before.elapsed(); +fn bench_config_abs(label: &str, items: usize, shards: usize, hot_count: usize, hot_weight: usize) { println!( - "lru promote total {elapsed:?}, {:?} avg per operation", - elapsed / ITERATIONS as u32 + "\n=== {label}: {items} items, {shards} shards ({} per shard), \ + {hot_count} items are {hot_weight}x hotter ===", + items / shards ); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote(dist.sample(&mut rng) as u64); - } - let elapsed = before.elapsed(); - println!( - "pingora lru promote total {elapsed:?}, {:?} avg per operation", - elapsed / ITERATIONS as u32 - ); + let weights = make_weights(items, hot_count, hot_weight); + let dist = Arc::new(WeightedIndex::new(&weights).unwrap()); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote_top_n(dist.sample(&mut rng) as u64, 10); + match shards { + 10 => bench_shards::<10>(items, &dist), + 32 => bench_shards::<32>(items, &dist), + _ => panic!("unsupported shard count: {shards}"), } - let elapsed = before.elapsed(); - println!( - "pingora lru promote_top_10 total {elapsed:?}, {:?} avg per operation", - elapsed / ITERATIONS as u32 - ); +} - // concurrent - - let lru = Arc::new(lru); - let mut handlers = vec![]; - for i in 0..THREADS { - let lru = lru.clone(); - let handler = thread::spawn(move || { - let mut rng = thread_rng(); - let dist = 
WeightedIndex::new(WEIGHTS).unwrap(); - let before = Instant::now(); - for _ in 0..ITERATIONS { - lru.lock().get(&(dist.sample(&mut rng) as u64)); - } - let elapsed = before.elapsed(); - println!( - "lru promote total {elapsed:?}, {:?} avg per operation thread {i}", - elapsed / ITERATIONS as u32 - ); - }); - handlers.push(handler); +/// Populate a fresh LRU with `items` entries. +fn make_lru(items: usize) -> pingora_lru::Lru<(), N> { + let lru = pingora_lru::Lru::<(), N>::with_capacity(items, items / N); + for i in 0..items { + lru.admit(i as u64, (), 1); } - for thread in handlers { - thread.join().unwrap(); + lru +} + +fn bench_shards(items: usize, dist: &Arc>) { + // Each variant gets a fresh LRU to avoid state contamination from + // prior runs warming hot items to the head. + + // --- Single-threaded --- + println!(" Single-threaded:"); + { + let lru = make_lru::(items); + let mut rng = thread_rng(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote(dist.sample(&mut rng) as u64); + } + let elapsed = before.elapsed(); + println!( + " promote: {elapsed:?} total, {:?} avg", + elapsed / ITERATIONS as u32, + ); } - let plru = Arc::new(plru); - - let mut handlers = vec![]; - for i in 0..THREADS { - let plru = plru.clone(); - let handler = thread::spawn(move || { - let mut rng = thread_rng(); - let dist = WeightedIndex::new(WEIGHTS).unwrap(); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote(dist.sample(&mut rng) as u64); - } - let elapsed = before.elapsed(); - println!( - "pingora lru promote total {elapsed:?}, {:?} avg per operation thread {i}", - elapsed / ITERATIONS as u32 - ); - }); - handlers.push(handler); + for top_n in [0, 3, 10, 50, 100] { + let lru = make_lru::(items); + let mut rng = thread_rng(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote_top_n(dist.sample(&mut rng) as u64, top_n); + } + let elapsed = before.elapsed(); + println!( + " promote_top_{top_n:<3} {elapsed:?} total, {:?} 
avg", + elapsed / ITERATIONS as u32, + ); } - for thread in handlers { - thread.join().unwrap(); + + // --- Multi-threaded --- + println!(" {THREADS}-threaded:"); + + { + let lru = Arc::new(make_lru::(items)); + let barrier = Arc::new(Barrier::new(THREADS)); + let mut handlers = vec![]; + for _ in 0..THREADS { + let lru = lru.clone(); + let dist = Arc::clone(dist); + let barrier = barrier.clone(); + handlers.push(thread::spawn(move || { + let mut rng = thread_rng(); + barrier.wait(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote(dist.sample(&mut rng) as u64); + } + before.elapsed() + })); + } + let elapsed: Vec<_> = handlers.into_iter().map(|h| h.join().unwrap()).collect(); + let avg = elapsed.iter().sum::() / THREADS as u32; + println!( + " promote: avg {avg:?}, {:?} avg per op", + avg / ITERATIONS as u32, + ); } - let mut handlers = vec![]; - for i in 0..THREADS { - let plru = plru.clone(); - let handler = thread::spawn(move || { - let mut rng = thread_rng(); - let dist = WeightedIndex::new(WEIGHTS).unwrap(); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote_top_n(dist.sample(&mut rng) as u64, 10); - } - let elapsed = before.elapsed(); - println!( - "pingora lru promote_top_10 total {elapsed:?}, {:?} avg per operation thread {i}", - elapsed / ITERATIONS as u32 - ); - }); - handlers.push(handler); + for top_n in [0, 3, 10, 50, 100] { + let lru = Arc::new(make_lru::(items)); + let barrier = Arc::new(Barrier::new(THREADS)); + let mut handlers = vec![]; + for _ in 0..THREADS { + let lru = lru.clone(); + let dist = Arc::clone(dist); + let barrier = barrier.clone(); + handlers.push(thread::spawn(move || { + let mut rng = thread_rng(); + barrier.wait(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote_top_n(dist.sample(&mut rng) as u64, top_n); + } + before.elapsed() + })); + } + let elapsed: Vec<_> = handlers.into_iter().map(|h| h.join().unwrap()).collect(); + let avg = elapsed.iter().sum::() / 
THREADS as u32; + println!( + " promote_top_{top_n:<3} avg {avg:?}, {:?} avg per op", + avg / ITERATIONS as u32, + ); } - for thread in handlers { - thread.join().unwrap(); +} + +fn main() { + // Benchmark tiers to simulate production-level data sizes: + // Small = original bench scale (10 items/shard) + // Typical = ~100K items/shard (3.2M total across 32 shards) + // Large = ~500K items/shard (16M total) — gated behind + // BENCH_LARGE=1 to avoid OOM on CI runners (~1.5GB heap) + // + // Note: the Typical tier allocates ~150MB per make_lru() call. With + // multiple variants (promote + 5 top_n values) × configs, total peak + // memory is ~1GB. Well within CI limits but notable for constrained machines. + + // Original benchmark scale (100 items, 10 shards = 10 per shard) + bench_config("Small (original bench scale)", 100, 10, 10, 100); + + // Typical (~100K items/shard), 10% hot + bench_config("Typical (100K/shard, 10% hot)", 3_200_000, 32, 10, 100); + + // Typical (~100K items/shard), heavy-hitter: only 10 items dominate + // Simulates viral content / popular API endpoints where a handful of + // assets receive the vast majority of traffic. 
+ bench_config_abs( + "Typical (100K/shard, 10 heavy hitters)", + 3_200_000, + 32, + 10, + 10_000, + ); + + // Typical (~100K items/shard), moderate hot set: 100 items dominate + bench_config_abs( + "Typical (100K/shard, 100 heavy hitters)", + 3_200_000, + 32, + 100, + 10_000, + ); + + // Large (~500K items/shard, ~1.5GB heap) + if std::env::var("BENCH_LARGE").is_ok() { + bench_config("Large (500K/shard, 10% hot)", 16_000_000, 32, 10, 100); + } else { + println!("\n=== Skipping large bench (set BENCH_LARGE=1 to enable) ==="); } } diff --git a/pingora-lru/src/lib.rs b/pingora-lru/src/lib.rs index 23728c4f..b2c5a642 100644 --- a/pingora-lru/src/lib.rs +++ b/pingora-lru/src/lib.rs @@ -25,11 +25,16 @@ use linked_list::{LinkedList, LinkedListIter}; use hashbrown::HashMap; use parking_lot::RwLock; +use rand::Rng; use std::sync::atomic::{AtomicUsize, Ordering}; /// The LRU with `N` shards pub struct Lru { units: [RwLock>; N], + /// Lock-free `Relaxed` shadow of each shard's item count, backing + /// [`Lru::shard_len`] and the P2C selection in [`Lru::evict_to_limit`]. + /// Maintained alongside [`Lru::len`] at every count-mutating site. + shard_lens: [AtomicUsize; N], weight: AtomicUsize, weight_limit: usize, len_watermark: Option, @@ -59,11 +64,16 @@ impl Lru { ) -> Self { // use the unsafe code from ArrayVec just to init the array let mut units = arrayvec::ArrayVec::<_, N>::new(); + let mut shard_lens = arrayvec::ArrayVec::<_, N>::new(); for _ in 0..N { units.push(RwLock::new(LruUnit::with_capacity(capacity))); + shard_lens.push(AtomicUsize::new(0)); } Lru { units: units.into_inner().map_err(|_| "").unwrap(), + shard_lens: shard_lens + .into_inner() + .expect("shard_lens ArrayVec filled with exactly N elements"), weight: AtomicUsize::new(0), weight_limit, len_watermark, @@ -73,6 +83,23 @@ impl Lru { } } + /// Increment item-count bookkeeping for `shard`. 
Both atomics use + /// `Relaxed` and run right after the shard mutation, possibly after the shard + /// write lock is released, so `len` and `shard_lens[shard]` may briefly disagree. + #[inline] + fn incr_count(&self, shard: usize) { + self.len.fetch_add(1, Ordering::Relaxed); + self.shard_lens[shard].fetch_add(1, Ordering::Relaxed); + } + + /// Decrement item-count bookkeeping for `shard`. See + /// [`Self::incr_count`]. + #[inline] + fn decr_count(&self, shard: usize) { + self.len.fetch_sub(1, Ordering::Relaxed); + self.shard_lens[shard].fetch_sub(1, Ordering::Relaxed); + } + /// Admit the key value to the [Lru] /// /// Return the shard index which the asset is added to @@ -91,7 +118,7 @@ impl Lru { self.weight.fetch_sub(old_weight, Ordering::Relaxed); } else { // Assume old_weight == 0 means a new item is admitted - self.len.fetch_add(1, Ordering::Relaxed); + self.incr_count(shard); } } shard @@ -128,9 +155,18 @@ impl Lru { /// Promote to the top n of the LRU /// - /// This function is a bit more efficient in terms of reducing lock contention because it - /// will acquire a write lock only if the key is outside top n but only acquires a read lock - /// when the key is already in the top n. + /// This function acquires a read lock first to check if the key is already + /// in the top `n` positions. If so, it returns early without a write lock. + /// Otherwise it falls through to a write lock for the actual promotion. + /// + /// **Performance note**: this optimization only helps when `n` covers a + /// significant fraction of the shard. At production scale (~100K+ items + /// per shard), hot items are rarely in the top N positions, so the + /// read-lock scan is usually wasted work that adds latency without + /// reducing contention. Benchmarks (`cargo bench --bench bench_lru`) + /// show that plain [`promote()`](Self::promote) is faster at scale. + /// Consider using `promote()` directly unless profiling shows a clear + /// benefit for your workload.
/// /// Return false if the item doesn't exist pub fn promote_top_n(&self, key: u64, top: usize) -> bool { @@ -141,14 +177,23 @@ impl Lru { unit.write().access(key) } - /// Evict at most one item from the given shard + /// Evict at most one item from the given shard, identified by the + /// hash-like `shard` seed (mapped into `0..N` via `% N`). /// - /// Return the evicted asset and its size if there is anything to evict + /// Return the evicted asset and its size if there is anything to evict. pub fn evict_shard(&self, shard: u64) -> Option<(T, usize)> { - let evicted = self.units[get_shard(shard, N)].write().evict(); + self.evict_shard_at(get_shard(shard, N)) + } + + /// Evict at most one item from the shard at index `shard` (in `0..N`). + /// Internal entry point that skips the `% N` round-trip in + /// [`Self::evict_shard`]. + fn evict_shard_at(&self, shard: usize) -> Option<(T, usize)> { + assert!(shard < N); + let evicted = self.units[shard].write().evict(); if let Some((_, weight)) = evicted.as_ref() { self.weight.fetch_sub(*weight, Ordering::Relaxed); - self.len.fetch_sub(1, Ordering::Relaxed); + self.decr_count(shard); self.evicted_weight.fetch_add(*weight, Ordering::Relaxed); self.evicted_len.fetch_add(1, Ordering::Relaxed); } @@ -159,43 +204,92 @@ impl Lru { /// /// Return a list of evicted items. /// - /// The evicted items are randomly selected from all the shards. + /// Each iteration selects the shard to evict from using the "power of two + /// choices" strategy: two shards are picked uniformly at random and the + /// one with more items is chosen (see Mitzenmacher, + /// "The Power of Two Choices in Randomized Load Balancing"). This biases + /// eviction toward longer shards and drives [`Self::shard_len`] toward a + /// uniform distribution, which keeps per-shard serialization cost (e.g. + /// `pingora_cache::eviction::lru::Manager::serialize_shard`) bounded. + /// + /// Selection is by item count, not weight, even when eviction is + /// triggered by `weight_limit`.
With heavily skewed item weights this + /// may evict more items than a weight-biased policy to reach the same + /// total weight — the tradeoff is intentional in favor of bounded + /// per-shard serialization cost. + /// + /// O(1) per iteration in the common case. If the chosen shard is + /// empty when we acquire its write lock (the Relaxed shadow may + /// not always reflect actual emptiness, and P2C may tie-break to + /// an empty shard when all shadow lengths are equal), we linearly + /// probe successive shard indices until one yields an item or we + /// wrap back to the starting shard — at which point every shard + /// was observed empty and we exit. Bounded by at most N probes + /// per outer iteration. pub fn evict_to_limit(&self) -> Vec<(T, usize)> { + self.evict_to_limit_with_rng(&mut rand::thread_rng()) + } + + /// Internal entry point for [`Self::evict_to_limit`] that lets tests + /// inject a seeded RNG for deterministic P2C selection. + fn evict_to_limit_with_rng(&self, rng: &mut R) -> Vec<(T, usize)> { let mut evicted = vec![]; let mut initial_weight = self.weight(); let mut initial_len = self.len(); - let mut shard_seed = rand::random(); // start from a random shard - let mut empty_shard = 0; - - // Entries can be admitted or removed from the LRU by others during the loop below - // Track initial size not to over evict due to entries admitted after the loop starts - // self.weight() / self.len() is also used not to over evict - // due to entries already removed by others - while ((initial_weight > self.weight_limit && self.weight() > self.weight_limit) + + // Transient over-limit weight can persist until the next + // admit/increment_weight call, which is acceptable because the + // next admission will re-trigger eviction. 
+ while (initial_weight > self.weight_limit && self.weight() > self.weight_limit) || self .len_watermark - .is_some_and(|w| initial_len > w && self.len() > w)) - && empty_shard < N + .is_some_and(|w| initial_len > w && self.len() > w) { - if let Some(i) = self.evict_shard(shard_seed) { - initial_weight -= i.1; - initial_len = initial_len.saturating_sub(1); - evicted.push(i) + // Power of two choices: pick the longer of two random shards. + // N == 1 short-circuits the redundant second roll. + let start = if N <= 1 { + 0 } else { - empty_shard += 1; + let a = rng.gen_range(0..N); + let b = rng.gen_range(0..N); + if self.shard_len(a) >= self.shard_len(b) { + a + } else { + b + } + }; + // Try the chosen shard first; on a miss (empty or raced), + // linearly probe successive indices. Wrapping back to + // `start` means every shard was observed empty, so we exit. + let mut shard = start; + let evicted_one = loop { + if let Some(item) = self.evict_shard_at(shard) { + break Some(item); + } + shard = (shard + 1) % N; + if shard == start { + break None; + } + }; + match evicted_one { + Some(i) => { + initial_weight = initial_weight.saturating_sub(i.1); + initial_len = initial_len.saturating_sub(1); + evicted.push(i); + } + None => break, } - // move on to the next shard - shard_seed += 1; } evicted } /// Remove the given asset. 
pub fn remove(&self, key: u64) -> Option<(T, usize)> { - let removed = self.units[get_shard(key, N)].write().remove(key); + let shard = get_shard(key, N); + let removed = self.units[shard].write().remove(key); if let Some((_, weight)) = removed.as_ref() { self.weight.fetch_sub(*weight, Ordering::Relaxed); - self.len.fetch_sub(1, Ordering::Relaxed); + self.decr_count(shard); } removed } @@ -204,12 +298,10 @@ impl Lru { /// /// Useful to recreate an LRU in most-to-least order pub fn insert_tail(&self, key: u64, data: T, weight: usize) -> bool { - if self.units[get_shard(key, N)] - .write() - .insert_tail(key, data, weight) - { + let shard = get_shard(key, N); + if self.units[shard].write().insert_tail(key, data, weight) { self.weight.fetch_add(weight, Ordering::Relaxed); - self.len.fetch_add(1, Ordering::Relaxed); + self.incr_count(shard); true } else { false @@ -226,7 +318,25 @@ impl Lru { self.units[get_shard(key, N)].read().peek_weight(key) } + /// Peek at the least-recently-used item in the given shard without removing it. + /// + /// Returns a clone of the data and the weight, or `None` if the shard is empty + /// or `shard >= N`. + pub fn peek_lru(&self, shard: usize) -> Option<(T, usize)> + where + T: Clone, + { + self.units + .get(shard)? + .read() + .peek_lru() + .map(|(data, weight)| (data.clone(), weight)) + } + /// Return the current total weight. + /// + /// Lock-free `Relaxed` load. Best-effort: not synchronized with + /// concurrent admissions or evictions on other threads. pub fn weight(&self) -> usize { self.weight.load(Ordering::Relaxed) } @@ -261,9 +371,15 @@ impl Lru { N } - /// Get the number of items inside a shard + /// Get the number of items inside a shard. + /// + /// Lock-free `Relaxed` load from a per-shard atomic shadow. Best-effort: + /// there is no cross-thread ordering between this and [`Self::len`], and + /// `Σ shard_len(i)` is not guaranteed to equal [`Self::len`] at any + /// given instant. 
Suitable for eviction-balance heuristics and + /// observability; not suitable for synchronization. pub fn shard_len(&self, shard: usize) -> usize { - self.units[shard].read().len() + self.shard_lens[shard].load(Ordering::Relaxed) } /// Get the weight (total size) inside a shard @@ -374,6 +490,19 @@ impl LruUnit { (node.data, node.weight) }) } + + /// Peek at the least-recently-used item without removing it. + /// + /// Returns a reference to the data and weight of the tail item, or `None` + /// if empty. + pub fn peek_lru(&self) -> Option<(&T, usize)> { + self.order + .tail() + .and_then(|idx| self.order.peek(idx)) + .and_then(|key| self.lookup_table.get(&key)) + .map(|node| (&node.data, node.weight)) + } + // TODO: scan the tail up to K elements to decide which ones to evict pub fn remove(&mut self, key: u64) -> Option<(T, usize)> { @@ -400,6 +529,7 @@ impl LruUnit { true } + #[cfg(test)] pub fn len(&self) -> usize { assert_eq!(self.lookup_table.len(), self.order.len()); self.lookup_table.len() @@ -583,7 +713,6 @@ mod test_lru { assert_eq!(lru.len(), 6); let evicted = lru.evict_to_limit(); - // NOTE: there is a low chance this test would fail see the TODO in evict_to_limit assert_eq!(lru.weight(), 6); assert_eq!(lru.len(), 3); assert_eq!(lru.evicted_weight(), 6); @@ -678,6 +807,119 @@ mod test_lru { assert!(!lru.insert_tail(6, 6, 7)); } + #[test] + fn test_evict_to_limit_p2c_bias() { + use rand::rngs::StdRng; + use rand::SeedableRng; + + // Shard 0 starts with 50 items, shard 1 with 10 (all weight 1). + // weight_limit=30 forces 30 evictions. P2C-by-length should pick + // shard 0 (the longer one) most of the time, driving toward + // balance. Expected share from shard 0: P2C ≈ 0.75 (P(shard 0) + // = 3/4 per pick while it stays longer), uniform ≈ 0.50, + // always-shortest ≈ 0.67 (capped by shard 1's 10 items), + // always-longest ≈ 1.0. 
The (0.65..0.95) window distinguishes + // P2C from uniform; the upper bound catches a degenerate + // always-longest regression. + const TRIALS: u64 = 50; + let mut total_from_shard0 = 0usize; + let mut total_evicted = 0usize; + + for seed in 0..TRIALS { + let lru = Lru::::with_capacity(30, 64); + for k in 0..50u64 { + // even keys → shard 0 + lru.admit(k * 2, k * 2, 1); + } + for k in 0..10u64 { + // odd keys → shard 1 + lru.admit(k * 2 + 1, k * 2 + 1, 1); + } + assert_eq!(lru.weight(), 60); + + let mut rng = StdRng::seed_from_u64(seed); + let evicted = lru.evict_to_limit_with_rng(&mut rng); + assert!( + lru.weight() <= 30, + "post-eviction weight {} exceeds limit", + lru.weight() + ); + total_from_shard0 += evicted.iter().filter(|(k, _)| k % 2 == 0).count(); + total_evicted += evicted.len(); + } + + assert!(total_evicted > 1000, "too few evictions: {total_evicted}"); + let share = total_from_shard0 as f64 / total_evicted as f64; + assert!( + (0.65..0.95).contains(&share), + "expected shard-0 eviction share in 0.65..0.95 (P2C ≈ 0.75); got {share}" + ); + } + + #[test] + fn test_evict_to_limit_break_on_empty_shards_over_limit() { + // Force `weight` above the limit while every shard is empty + // (simulating bookkeeping skew). The linear probe must wrap + // around all N shards and exit cleanly. + let lru = Lru::::with_capacity(10, 16); + lru.weight.fetch_add(100, Ordering::Relaxed); + assert_eq!(lru.evict_to_limit().len(), 0); + } + + #[test] + fn test_watermark_eviction_with_zero_weight_items() { + // All items have weight 0 so the weight-limit guard never fires; + // only the length watermark drives eviction. P2C-by-length should + // still reach the watermark regardless of weight values. 
+ let lru = Lru::::with_capacity_and_watermark(usize::MAX / 2, 10, Some(2)); + for k in 0..6u64 { + lru.insert_tail(k, k, 0); + } + assert_eq!(lru.len(), 6); + assert_eq!(lru.weight(), 0); + let evicted = lru.evict_to_limit(); + assert_eq!(lru.len(), 2); + assert_eq!(evicted.len(), 4); + } + + #[test] + fn test_evict_to_limit_with_mostly_empty_shards() { + // 7/8 shards empty: both random rolls land on empty shards ~77% + // of the time, exercising the linear-probe fallback heavily. + let lru = Lru::::with_capacity(2, 16); + for k in 0..8u64 { + // multiples of 8 hash to shard 0 + lru.admit(k * 8, k * 8, 1); + } + assert_eq!(lru.weight(), 8); + + let evicted = lru.evict_to_limit(); + assert_eq!(lru.weight(), 2); + assert_eq!(evicted.len(), 6); + assert!(evicted.iter().all(|(k, _)| k % 8 == 0)); + } + + #[test] + fn test_evict_to_limit_below_limit_returns_immediately() { + // Smoke test: outer guard short-circuits when already under limit. + let lru = Lru::::with_capacity(0, 16); + assert_eq!(lru.evict_to_limit().len(), 0); + } + + #[test] + fn test_evict_to_limit_n1() { + // N=1 is a trivial special case in the selection logic; ensure + // basic eviction still works. 
+ let lru = Lru::::with_capacity(2, 16); + for k in 0..5u64 { + lru.admit(k, k, 1); + } + assert_eq!(lru.weight(), 5); + let evicted = lru.evict_to_limit(); + assert_eq!(lru.weight(), 2); + assert_eq!(evicted.len(), 3); + } + #[test] fn test_watermark_eviction() { const WEIGHT_LIMIT: usize = usize::MAX / 2; @@ -696,6 +938,29 @@ mod test_lru { assert_eq!(evicted.len(), 2); assert_eq!(lru.evicted_len(), 2); } + + #[test] + fn test_peek_lru() { + let lru = Lru::::with_capacity(10, 10); + + // empty shard + assert!(lru.peek_lru(0).is_none()); + + lru.admit(1, 10, 1); + assert_eq!(lru.peek_lru(0).unwrap(), (10, 1)); + + lru.admit(2, 20, 2); + // key 1 is LRU tail + assert_eq!(lru.peek_lru(0).unwrap(), (10, 1)); + + // promote key 1 + lru.promote(1); + // key 2 is now LRU tail + assert_eq!(lru.peek_lru(0).unwrap(), (20, 2)); + + // out-of-bounds returns None + assert!(lru.peek_lru(999).is_none()); + } } #[cfg(test)] @@ -865,4 +1130,33 @@ mod test_lru_unit { assert_eq!(lru.used_weight(), 1 + 3 + 4 + 5); assert_lru(&lru, &[2, 3, 4, 5]); } + + #[test] + fn test_peek_lru() { + let mut lru = LruUnit::with_capacity(10); + + // empty returns None + assert!(lru.peek_lru().is_none()); + + // single item is both head and tail + lru.admit(1, 10, 1); + let (data, weight) = lru.peek_lru().unwrap(); + assert_eq!(*data, 10); + assert_eq!(weight, 1); + + // second admission pushes first to tail + lru.admit(2, 20, 2); + let (data, _) = lru.peek_lru().unwrap(); + assert_eq!(*data, 10); // key 1 is LRU tail + + // promote key 1 — now key 2 is tail + lru.access(1); + let (data, _) = lru.peek_lru().unwrap(); + assert_eq!(*data, 20); // key 2 is now LRU tail + + // peek doesn't remove + assert!(lru.peek_lru().is_some()); + assert!(lru.peek(1).is_some()); + assert!(lru.peek(2).is_some()); + } } diff --git a/pingora-prometheus/Cargo.toml b/pingora-prometheus/Cargo.toml new file mode 100644 index 00000000..d9701213 --- /dev/null +++ b/pingora-prometheus/Cargo.toml @@ -0,0 +1,22 @@ +[package] 
+name = "pingora-prometheus" +version = "0.8.0" +authors = ["Pingora Team at Cloudflare "] +license = "Apache-2.0" +edition = "2021" +repository = "https://github.com/cloudflare/pingora" +categories = ["asynchronous", "network-programming"] +keywords = ["async", "http", "prometheus", "pingora"] +description = """ +A Prometheus metrics HTTP server for pingora services. +""" + +[lib] +name = "pingora_prometheus" +path = "src/lib.rs" + +[dependencies] +pingora-core = { version = "0.8.0", path = "../pingora-core", default-features = false } +prometheus = "0.14" +async-trait = { workspace = true } +http = { workspace = true } diff --git a/pingora-prometheus/src/lib.rs b/pingora-prometheus/src/lib.rs new file mode 100644 index 00000000..cfd90b89 --- /dev/null +++ b/pingora-prometheus/src/lib.rs @@ -0,0 +1,131 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![warn(clippy::all)] + +//! A Prometheus metrics HTTP server for [pingora](https://docs.rs/pingora) services. +//! +//! This crate provides [`PrometheusHttpApp`] and [`PrometheusServer`], which serve +//! all [static metrics](https://docs.rs/prometheus/latest/prometheus/index.html#static-metrics) +//! collected via the [`prometheus`] crate as an HTTP endpoint. +//! +//! # Example +//! +//! ```rust,ignore +//! use pingora_core::services::listening::Service; +//! use pingora_prometheus::new_prometheus_server; +//! +//! let mut prometheus_service = Service::new( +//! 
"Prometheus HTTP".to_string(), +//! new_prometheus_server(), +//! ); +//! prometheus_service.add_tcp("127.0.0.1:6150"); +//! server.add_service(prometheus_service); +//! ``` +//! +//! Or use the convenience function: +//! +//! ```rust,ignore +//! let mut prometheus_service = pingora_prometheus::prometheus_http_service(); +//! prometheus_service.add_tcp("127.0.0.1:6150"); +//! server.add_service(prometheus_service); +//! ``` + +use async_trait::async_trait; +use http::Response; +use prometheus::{Encoder, TextEncoder}; + +use pingora_core::apps::http_app::{HttpServer, ServeHttp}; +use pingora_core::modules::http::compression::ResponseCompressionBuilder; +use pingora_core::protocols::http::ServerSession; +use pingora_core::services::listening::Service; + +/// Re-export of the [`prometheus`] crate. +/// +/// Use this re-export to ensure your metrics are registered in the same +/// global registry that [`PrometheusHttpApp`] gathers from, avoiding +/// version mismatches that would cause metrics to silently not appear. +/// +/// # Example +/// +/// ```rust,ignore +/// use pingora_prometheus::prometheus::{self, register_int_counter, IntCounter}; +/// use once_cell::sync::Lazy; +/// +/// static REQUESTS: Lazy = Lazy::new(|| { +/// register_int_counter!("requests_total", "Total requests").unwrap() +/// }); +/// ``` +pub use prometheus; + +/// An HTTP application that reports Prometheus metrics. +/// +/// This application will report all the [static metrics](https://docs.rs/prometheus/latest/prometheus/index.html#static-metrics) +/// collected via the [Prometheus](https://docs.rs/prometheus/) crate. +/// +/// Currently serves metrics on all request paths. By convention, Prometheus +/// scrapers expect metrics at `/metrics`. Since this app is typically bound +/// to a dedicated listener address, this works in practice, but callers +/// should be aware of this if sharing the listener with other routes. 
+// TODO: consider restricting to `/metrics` and returning 404 for other paths
+pub struct PrometheusHttpApp;
+
+#[async_trait]
+impl ServeHttp for PrometheusHttpApp {
+    async fn response(&self, _http_session: &mut ServerSession) -> Response<Vec<u8>> {
+        let encoder = TextEncoder::new();
+        let metric_families = prometheus::gather();
+        let mut buffer = vec![];
+        encoder.encode(&metric_families, &mut buffer).unwrap();
+        Response::builder()
+            .status(200)
+            .header(http::header::CONTENT_TYPE, encoder.format_type())
+            .header(http::header::CONTENT_LENGTH, buffer.len())
+            .body(buffer)
+            .unwrap()
+    }
+}
+
+/// The [`HttpServer`] for [`PrometheusHttpApp`].
+///
+/// This type provides the functionality of [`PrometheusHttpApp`] with gzip
+/// compression enabled (level 7).
+pub type PrometheusServer = HttpServer<PrometheusHttpApp>;
+
+/// Create a new [`PrometheusServer`] with compression enabled.
+pub fn new_prometheus_server() -> PrometheusServer {
+    let mut server = PrometheusServer::new_app(PrometheusHttpApp);
+    // enable gzip level 7 compression
+    server.add_module(ResponseCompressionBuilder::enable(7));
+    server
+}
+
+/// Create a Prometheus HTTP [`Service`] ready to have endpoints added.
+///
+/// This is a convenience function that creates a [`Service`] wrapping a
+/// [`PrometheusServer`] with compression enabled.
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let mut prometheus_service = pingora_prometheus::prometheus_http_service();
+/// prometheus_service.add_tcp("127.0.0.1:6150");
+/// server.add_service(prometheus_service);
+/// ```
+pub fn prometheus_http_service() -> Service<PrometheusServer> {
+    Service::new(
+        "Prometheus metric HTTP".to_string(),
+        new_prometheus_server(),
+    )
+}
diff --git a/pingora-proxy/Cargo.toml b/pingora-proxy/Cargo.toml
index 1f367d89..e1cc1cbb 100644
--- a/pingora-proxy/Cargo.toml
+++ b/pingora-proxy/Cargo.toml
@@ -36,17 +36,21 @@ regex = "1"
 rand = "0.8"
 
 [dev-dependencies]
-reqwest = { version = "0.11", features = [
+reqwest = { version = "0.12", features = [
     "gzip",
     "rustls-tls",
+    "http2",
 ], default-features = false }
 httparse = { workspace = true }
 tokio-test = "0.4"
 env_logger = "0.11"
-hyper = "0.14"
-tokio-tungstenite = "0.20.1"
+hyper = { version = "1", features = ["client", "http1", "http2"] }
+hyper-util = { version = "0.1", features = ["client-legacy", "http1", "http2"] }
+http-body-util = "0.1"
+tokio-tungstenite = "0.26"
 pingora-limits = { version = "0.8.0", path = "../pingora-limits" }
 pingora-load-balancing = { version = "0.8.0", path = "../pingora-load-balancing", default-features=false }
+pingora-prometheus = { version = "0.8.0", path = "../pingora-prometheus" }
 prometheus = "0"
 futures-util = "0.3"
 serde = { version = "1.0", features = ["derive"] }
@@ -54,7 +58,7 @@ serde_json = "1.0"
 serde_yaml = "0.9"
 
 [target.'cfg(unix)'.dev-dependencies]
-hyperlocal = "0.8"
+hyperlocal = "0.9"
 
 [features]
 default = []
@@ -69,8 +73,9 @@ s2n = ["pingora-core/s2n", "pingora-cache/s2n", "any_tls"]
 openssl_derived = ["any_tls"]
 any_tls = []
 sentry = ["pingora-core/sentry"]
+upstream_modules = []
 connection_filter = ["pingora-core/connection_filter"]
-prometheus = ["pingora-core/prometheus"]
+trace = ["pingora-cache/trace"]
 
 [[example]]
 name = "connection_filter"
diff --git a/pingora-proxy/examples/gateway.rs b/pingora-proxy/examples/gateway.rs
index 
e320688f..79c1646c 100644 --- a/pingora-proxy/examples/gateway.rs +++ b/pingora-proxy/examples/gateway.rs @@ -129,12 +129,8 @@ fn main() { my_proxy.add_tcp("0.0.0.0:6191"); my_server.add_service(my_proxy); - #[cfg(feature = "prometheus")] - let mut prometheus_service_http = - pingora_core::services::listening::Service::prometheus_http_service(); - #[cfg(feature = "prometheus")] + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("127.0.0.1:6192"); - #[cfg(feature = "prometheus")] my_server.add_service(prometheus_service_http); my_server.run_forever(); diff --git a/pingora-proxy/src/lib.rs b/pingora-proxy/src/lib.rs index 52a89cbd..4ce9e5e5 100644 --- a/pingora-proxy/src/lib.rs +++ b/pingora-proxy/src/lib.rs @@ -46,7 +46,7 @@ use pingora_http::{RequestHeader, ResponseHeader}; use std::fmt::Debug; use std::str; use std::sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicU64, Ordering}, Arc, }; use std::time::Duration; @@ -119,6 +119,8 @@ where pub server_options: Option, pub h2_options: Option, pub downstream_modules: HttpModules, + #[cfg(feature = "upstream_modules")] + pub upstream_modules: HttpModules, max_retries: usize, process_custom_session: Option>, } @@ -153,6 +155,8 @@ impl HttpProxy { server_options: None, h2_options: None, downstream_modules: HttpModules::new(), + #[cfg(feature = "upstream_modules")] + upstream_modules: HttpModules::new(), max_retries: conf.max_retries, process_custom_session: None, } @@ -184,12 +188,25 @@ where shutdown_flag: Arc::new(AtomicBool::new(false)), server_options, downstream_modules: HttpModules::new(), + #[cfg(feature = "upstream_modules")] + upstream_modules: HttpModules::new(), max_retries: conf.max_retries, process_custom_session: on_custom, h2_options: None, } } + /// Return the number of times a pooled upstream connection was found to contain + /// unexpected data from the server. 
+ pub fn unexpected_data_connection_count(&self) -> u64 { + self.client_upstream.unexpected_data_connection_count() + } + + /// Return a shared reference to the unexpected data connection counter for periodic metric reporting. + pub fn unexpected_data_connection_counter(&self) -> Arc { + self.client_upstream.unexpected_data_connection_counter() + } + /// Initialize the downstream modules for this proxy. /// /// This method must be called after creating an [`HttpProxy`] with [`HttpProxy::new()`] @@ -204,6 +221,8 @@ where { self.inner .init_downstream_modules(&mut self.downstream_modules); + #[cfg(feature = "upstream_modules")] + self.inner.init_upstream_modules(&mut self.upstream_modules); } async fn handle_new_request( @@ -464,6 +483,10 @@ pub struct Session { pub subrequest_spawner: Option, // Downstream filter modules pub downstream_modules_ctx: HttpModuleCtx, + /// Upstream filter modules. These run before `upstream_compression` and see the raw + /// (pre-compression) upstream response body. + #[cfg(feature = "upstream_modules")] + pub upstream_modules_ctx: HttpModuleCtx, /// Upstream response body bytes received (payload only). Set by proxy layer. /// TODO: move this into an upstream session digest for future fields. 
upstream_body_bytes_received: usize, @@ -477,6 +500,7 @@ impl Session { fn new( downstream_session: impl Into>, downstream_modules: &HttpModules, + #[cfg(feature = "upstream_modules")] upstream_modules: &HttpModules, shutdown_flag: Arc, ) -> Self { Session { @@ -489,6 +513,8 @@ impl Session { subrequest_ctx: None, subrequest_spawner: None, // optionally set later on downstream_modules_ctx: downstream_modules.build_ctx(), + #[cfg(feature = "upstream_modules")] + upstream_modules_ctx: upstream_modules.build_ctx(), upstream_body_bytes_received: 0, upstream_write_pending_time: Duration::ZERO, shutdown_flag, @@ -504,6 +530,8 @@ impl Session { Self::new( Box::new(HttpSession::new_http1(stream)), &modules, + #[cfg(feature = "upstream_modules")] + &HttpModules::new(), Arc::new(AtomicBool::new(false)), ) } @@ -516,10 +544,47 @@ impl Session { Self::new( Box::new(HttpSession::new_http1(stream)), downstream_modules, + #[cfg(feature = "upstream_modules")] + &HttpModules::new(), Arc::new(AtomicBool::new(false)), ) } + /// Run upstream module filters on the given [`HttpTask`]. + /// + /// Upstream modules process each task **before** `upstream_compression` and + /// see the raw (pre-compression) upstream response. Like the downstream + /// module path, `response_trailer_filter` and `response_done_filter` return + /// values are converted to body tasks when present. + #[cfg(feature = "upstream_modules")] + pub async fn upstream_modules_filter_task(&mut self, t: &mut HttpTask) -> Result<()> { + match t { + HttpTask::Header(header, eos) => { + self.upstream_modules_ctx + .response_header_filter(header, *eos) + .await?; + } + HttpTask::Body(body, eos) | HttpTask::UpgradedBody(body, eos) => { + self.upstream_modules_ctx.response_body_filter(body, *eos)?; + } + HttpTask::Trailer(trailers) => { + if let Some(buf) = self + .upstream_modules_ctx + .response_trailer_filter(trailers)? 
+ { + *t = HttpTask::Body(Some(buf), true); + } + } + HttpTask::Done => { + if let Some(buf) = self.upstream_modules_ctx.response_done_filter()? { + *t = HttpTask::Body(Some(buf), true); + } + } + HttpTask::Failed(_) => {} + } + Ok(()) + } + pub fn as_downstream_mut(&mut self) -> &mut HttpSession { &mut self.downstream_session } @@ -587,57 +652,115 @@ impl Session { .await } - pub async fn write_response_tasks(&mut self, mut tasks: Vec) -> Result { - let mut seen_upgraded = self.was_upgraded(); - for task in tasks.iter_mut() { - match task { - HttpTask::Header(resp, end) => { - self.downstream_modules_ctx - .response_header_filter(resp, *end) - .await?; - } - HttpTask::Body(data, end) => { - self.downstream_modules_ctx - .response_body_filter(data, *end)?; - } - HttpTask::UpgradedBody(data, end) => { - seen_upgraded = true; - self.downstream_modules_ctx - .response_body_filter(data, *end)?; - } - HttpTask::Trailer(trailers) => { - if let Some(buf) = self - .downstream_modules_ctx - .response_trailer_filter(trailers)? - { - // Write the trailers into the body if the filter - // returns a buffer. - // - // Note, this will not work if end of stream has already - // been seen or we've written content-length bytes. - // (Trailers should never come after upgraded body) - *task = HttpTask::Body(Some(buf), true); - } - } - HttpTask::Done => { - // `Done` can be sent in certain response paths to mark end - // of response if not already done via trailers or body with - // end flag set. - // If the filter returns body bytes on Done, - // write them into the response. + // Run downstream module response filters on a single task, updating + // `seen_upgraded` to track whether an upgrade has been seen. Used by both + // `send_downstream_proxy_task` and `write_response_tasks`. 
+ async fn downstream_response_task_filter( + &mut self, + task: &mut HttpTask, + seen_upgraded: &mut bool, + ) -> Result<()> { + match task { + HttpTask::Header(resp, end) => { + self.downstream_modules_ctx + .response_header_filter(resp, *end) + .await?; + } + HttpTask::Body(data, end) => { + self.downstream_modules_ctx + .response_body_filter(data, *end)?; + } + HttpTask::UpgradedBody(data, end) => { + *seen_upgraded = true; + self.downstream_modules_ctx + .response_body_filter(data, *end)?; + } + HttpTask::Trailer(trailers) => { + if let Some(buf) = self + .downstream_modules_ctx + .response_trailer_filter(trailers)? + { + // Write the trailers into the body if the filter + // returns a buffer. // // Note, this will not work if end of stream has already // been seen or we've written content-length bytes. - if let Some(buf) = self.downstream_modules_ctx.response_done_filter()? { - if seen_upgraded { - *task = HttpTask::UpgradedBody(Some(buf), true); - } else { - *task = HttpTask::Body(Some(buf), true); - } + // (Trailers should never come after upgraded body) + *task = HttpTask::Body(Some(buf), true); + } + } + HttpTask::Done => { + // `Done` can be sent in certain response paths to mark end + // of response if not already done via trailers or body with + // end flag set. + // If the filter returns body bytes on Done, + // write them into the response. + // + // Note, this will not work if end of stream has already + // been seen or we've written content-length bytes. + if let Some(buf) = self.downstream_modules_ctx.response_done_filter()? { + if *seen_upgraded { + *task = HttpTask::UpgradedBody(Some(buf), true); + } else { + *task = HttpTask::Body(Some(buf), true); } } - _ => { /* Failed */ } } + _ => { /* Failed */ } + } + Ok(()) + } + + /// Queue a downstream proxy task for cancel-safe writing after running + /// downstream module filters. This allows decoupling cache writes from + /// downstream writes. 
+ /// + /// Only works with sessions that support the proxy task API (currently H1). + /// + /// # Panics + /// Panics if the session doesn't support the proxy task API. + /// Use `write_response_tasks()` for sessions that don't support the proxy task API. + pub async fn send_downstream_proxy_task(&mut self, mut task: HttpTask) -> Result<()> { + let mut seen_upgraded = self.was_upgraded(); + self.downstream_response_task_filter(&mut task, &mut seen_upgraded) + .await?; + self.downstream_session.send_downstream_proxy_task(task); + Ok(()) + } + + /// Enable or disable the cancel-safe proxy task API for this session. + /// + /// When disabled, the proxy falls back to the blocking `write_response_tasks` + /// path. This can be called from request filters to opt out on a per-request + /// basis. + pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { + self.downstream_session.set_proxy_tasks_enabled(enabled); + } + + /// Check if there are pending downstream tasks queued for writing. + /// Used for backpressure - don't queue more cache tasks if we have pending writes. + /// Returns false for sessions that don't support the proxy task API. + pub fn has_pending_downstream_tasks(&self) -> bool { + self.downstream_session.supports_proxy_task_api() + && self.downstream_session.has_pending_downstream_proxy_tasks() + } + + /// Write all queued downstream proxy tasks. This is cancel-safe and can be called + /// in a select! loop while waiting for upstream tasks. + /// For sessions that don't support the proxy task API, this is a no-op. 
+ pub async fn write_downstream_proxy_tasks(&mut self) -> Result { + if self.downstream_session.supports_proxy_task_api() { + self.downstream_session.write_downstream_proxy_tasks().await + } else { + Ok(false) + } + } + + pub async fn write_response_tasks(&mut self, mut tasks: Vec) -> Result { + let mut seen_upgraded = self.was_upgraded(); + for task in tasks.iter_mut() { + self.downstream_response_task_filter(task, &mut seen_upgraded) + .await?; } self.downstream_session.response_duplex_vec(tasks).await } @@ -1030,6 +1153,8 @@ where Some(downstream_session) => Session::new( downstream_session, &self.downstream_modules, + #[cfg(feature = "upstream_modules")] + &self.upstream_modules, self.shutdown_flag.clone(), ), None => return, // bad request @@ -1157,6 +1282,8 @@ where Some(downstream_session) => Session::new( downstream_session, &self.downstream_modules, + #[cfg(feature = "upstream_modules")] + &self.upstream_modules, self.shutdown_flag.clone(), ), None => return None, // bad request diff --git a/pingora-proxy/src/proxy_cache.rs b/pingora-proxy/src/proxy_cache.rs index 43b2ace9..fecf6fbd 100644 --- a/pingora-proxy/src/proxy_cache.rs +++ b/pingora-proxy/src/proxy_cache.rs @@ -487,6 +487,7 @@ where } } + // No enabled() guard: no concurrent upstream can disable cache here. if let Err(e) = session.cache.finish_hit_handler().await { warn!("Error during finish_hit_handler: {}", e); } @@ -1276,7 +1277,7 @@ pub mod range_filter { pub ranges: Vec>, pub boundary: String, total_length: usize, - content_type: Option, + pub content_type: Option, } impl MultiRangeInfo { diff --git a/pingora-proxy/src/proxy_common.rs b/pingora-proxy/src/proxy_common.rs index e1d36f69..6c40760c 100644 --- a/pingora-proxy/src/proxy_common.rs +++ b/pingora-proxy/src/proxy_common.rs @@ -1,3 +1,17 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /// Possible downstream states during request multiplexing #[derive(Debug, Clone, Copy)] pub(crate) enum DownstreamStateMachine { @@ -36,19 +50,28 @@ impl DownstreamStateMachine { matches!(self, Self::Errored) } - /// Move the state machine to Finished state if `set` is true + /// Move the state machine to Finished state if `set` is true. + /// + /// No-op when the current state is [`Errored`](Self::Errored) — once errored the + /// downstream connection must not be reused, and late upstream chunks arriving + /// via `rx.recv()` must not overwrite that decision. pub fn maybe_finished(&mut self, set: bool) { - if set { + if set && !self.is_errored() { *self = Self::ReadingFinished } } - /// Reset if we should continue reading from the downstream again. - /// Only used with upgraded connections when body mode changes. + /// Reset to [`Reading`](Self::Reading) for upgraded connections when body mode changes. + /// + /// No-op when the current state is [`Errored`](Self::Errored). pub fn reset(&mut self) { - *self = Self::Reading; + if !self.is_errored() { + *self = Self::Reading; + } } + /// Transition to [`Errored`](Self::Errored). This is a terminal state: once entered, + /// no other state transition is permitted and the connection must not be reused. 
pub fn to_errored(&mut self) { *self = Self::Errored } @@ -97,3 +120,61 @@ impl ResponseStateMachine { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn normal_lifecycle() { + let mut ds = DownstreamStateMachine::new(false); + assert!(ds.is_reading()); + assert!(ds.can_poll()); + assert!(!ds.is_errored()); + + ds.maybe_finished(true); + assert!(!ds.is_reading()); + assert!(ds.is_done()); + assert!(ds.can_poll()); // ReadingFinished still allows polling (for idle) + assert!(!ds.is_errored()); + } + + #[test] + fn errored_is_terminal() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + assert!(ds.is_done()); + } + + /// `maybe_finished(false)` is always a no-op regardless of state. + #[test] + fn maybe_finished_false_is_noop() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + ds.maybe_finished(false); // must not panic + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + } + + /// `maybe_finished(true)` on `Errored` is a no-op — `Errored` is terminal. + #[test] + fn maybe_finished_true_noop_on_errored() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + ds.maybe_finished(true); // must not overwrite Errored + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + } + + /// `reset()` on `Errored` is a no-op — `Errored` is terminal. 
+ #[test] + fn reset_noop_on_errored() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + ds.reset(); // must not overwrite Errored + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + } +} diff --git a/pingora-proxy/src/proxy_custom.rs b/pingora-proxy/src/proxy_custom.rs index 63079111..31cb3a52 100644 --- a/pingora-proxy/src/proxy_custom.rs +++ b/pingora-proxy/src/proxy_custom.rs @@ -257,7 +257,90 @@ where } } - // returns whether server (downstream) session can be reused + #[allow(clippy::too_many_arguments)] + async fn process_upstream_tasks_custom( + &self, + session: &mut Session, + ctx: &mut SV::CTX, + initial_task: HttpTask, + rx: &mut mpsc::Receiver, + serve_from_cache: &mut ServeFromCache, + range_body_filter: &mut proxy_cache::range_filter::RangeBodyFilter, + response_state: &mut ResponseStateMachine, + ) -> Result> + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + if serve_from_cache.should_discard_upstream() { + // just drain, do we need to do anything else? 
+ return Ok(None); + } + + // Batch: pull as many tasks as we can from rx + let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + tasks.push(initial_task); + while let Ok(task) = rx.try_recv() { + tasks.push(task); + } + + /* run filters before sending to downstream */ + let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + for mut t in tasks { + if self.revalidate_or_stale(session, &mut t, ctx).await { + serve_from_cache.enable(); + response_state.enable_cached_response(); + // skip downstream filtering entirely as the 304 will not be sent + break; + } + #[cfg(feature = "upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } + #[cfg(feature = "upstream_modules")] + session.upstream_modules_filter_task(&mut t).await?; + session.upstream_compression.response_filter(&mut t); + // check error and abort + // otherwise the error is surfaced via write_response_tasks() + if !serve_from_cache.should_send_to_downstream() { + if let HttpTask::Failed(e) = t { + return Err(e); + } + } + filtered_tasks.push( + self.custom_response_filter( + session, + t, + ctx, + serve_from_cache, + range_body_filter, + false, + ) + .await?, + ); + if serve_from_cache.is_miss_header() { + response_state.enable_cached_response(); + } + } + + if !serve_from_cache.should_send_to_downstream() { + // TODO: need to derive response_done from filtered_tasks in case downstream failed already + return Ok(None); + } + + let response_done = session.write_response_tasks(filtered_tasks).await?; + + Ok(Some(response_done)) + } + + // TODO: pre-existing inconsistency with proxy_h1/proxy_h2 to address in a follow-up: + // upstream task rx.recv() branch is missing + // downstream_state.maybe_finished(session.is_body_done()) after processing. 
proxy_h1 has + // this because upgrade responses can force the body done — since custom upstreams can + // serve H1 downstreams that support upgrades, the same may be needed here. + // Returns whether server (downstream) session can be reused #[allow(clippy::too_many_arguments)] async fn custom_bidirection_down_to_up( &self, @@ -303,6 +386,8 @@ where let mut serve_from_cache = ServeFromCache::new(); let mut range_body_filter = proxy_cache::range_filter::RangeBodyFilter::new(); + let mut next_upstream_task: Option = None; + let mut upstream_custom = true; let mut downstream_custom = true; @@ -361,93 +446,145 @@ where }; }, - task = rx.recv(), if !response_state.upstream_done() => { - debug!("upstream event"); - + // Handle buffered upstream task from previous iteration + task = async { next_upstream_task.take() }, if next_upstream_task.is_some() => { + debug!("buffered upstream event: {:?}", task); if let Some(t) = task { - debug!("upstream event custom: {:?}", t); - if serve_from_cache.should_discard_upstream() { - // just drain, do we need to do anything else? 
- continue; - } - // pull as many tasks as we can - let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - tasks.push(t); - while let Ok(task) = rx.try_recv() { - tasks.push(task); - } - - /* run filters before sending to downstream */ - let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - for mut t in tasks { - if self.revalidate_or_stale(session, &mut t, ctx).await { - serve_from_cache.enable(); - response_state.enable_cached_response(); - // skip downstream filtering entirely as the 304 will not be sent - break; - } - session.upstream_compression.response_filter(&mut t); - // check error and abort - // otherwise the error is surfaced via write_response_tasks() - if !serve_from_cache.should_send_to_downstream() { - if let HttpTask::Failed(e) = t { - return Err(e); - } - } - filtered_tasks.push( - self.custom_response_filter(session, t, ctx, - &mut serve_from_cache, - &mut range_body_filter, false).await?); - if serve_from_cache.is_miss_header() { - response_state.enable_cached_response(); - } - } - - if !serve_from_cache.should_send_to_downstream() { - // TODO: need to derive response_done from filtered_tasks in case downstream failed already + let Some(response_done) = self.process_upstream_tasks_custom( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. 
serve_from_cache continue; - } + }; + response_state.maybe_set_upstream_done(response_done); + } else { + debug!("empty upstream event"); + response_state.maybe_set_upstream_done(true); + } + }, + task = rx.recv(), if !response_state.upstream_done() && next_upstream_task.is_none() => { + debug!("upstream event: {:?}", task); + if let Some(t) = task { let upgraded = session.was_upgraded(); - let response_done = session.write_response_tasks(filtered_tasks).await?; + let Some(response_done) = self.process_upstream_tasks_custom( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache + continue; + }; if !upgraded && session.was_upgraded() && downstream_state.can_poll() { // just upgraded, the downstream state should be reset to continue to // poll body trace!("reset downstream state on upgrade"); downstream_state.reset(); } - response_state.maybe_set_upstream_done(response_done); } else { debug!("empty upstream event"); response_state.maybe_set_upstream_done(true); } - } + }, task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), - if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { + if !response_state.cached_done() + && !downstream_state.is_errored() + && serve_from_cache.is_on() + && !session.has_pending_downstream_tasks() => { // backpressure: don't queue if pending writes + let task = self.custom_response_filter(session, task?, ctx, &mut serve_from_cache, &mut range_body_filter, true).await?; - match session.write_response_tasks(vec![task]).await { - Ok(b) => response_state.maybe_set_cache_done(b), - Err(e) => if serve_from_cache.is_miss() { - // give up writing to downstream but wait for upstream cache write to finish - downstream_state.to_errored(); - response_state.maybe_set_cache_done(true); - warn!( - "Downstream Error ignored during caching: {}, {}", - 
e, - self.inner.request_summary(session, ctx) - ); - continue; - } else { - return Err(e); + + if session.downstream_session.supports_proxy_task_api() { + session.send_downstream_proxy_task(task).await?; + } else { + match session.write_response_tasks(vec![task]).await { + Ok(b) => response_state.maybe_set_cache_done(b), + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream Error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + session.downstream_session.on_proxy_failure(e); + continue; + } else { + return Err(e); + } + } + // A storage error can disable cache between cached_done + // being set and here; see the same guard in proxy_h1.rs. + if response_state.cached_done() && session.cache.enabled() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } } } - if response_state.cached_done() { - if let Err(e) = session.cache.finish_hit_handler().await { - warn!("Error during finish_hit_handler: {}", e); + } + + // Write queued downstream proxy tasks while also polling for upstream tasks. + // This allows cache writes to continue even when downstream is stalled. + // + // "Gate" branch: ready(()) resolves immediately, so the guard controls + // whether we enter. This is not a busy-loop because every path through + // the inner select either (a) drains all pending tasks via + // write_downstream_proxy_tasks (making the guard false), (b) stores an + // upstream task in next_upstream_task (making the guard false), or + // (c) blocks on real I/O inside the nested select. + _ = std::future::ready(()), if session.has_pending_downstream_tasks() && next_upstream_task.is_none() => { + tokio::select! 
{ + // Try to write downstream proxy tasks (cancel-safe) + write_result = session.write_downstream_proxy_tasks() => { + match write_result { + Ok(end) => { + response_state.maybe_set_cache_done(end); + // See enabled() guard comment above. + if response_state.cached_done() && session.cache.enabled() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } + } + } + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream write error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + session.downstream_session.on_proxy_failure(e); + } else { + return Err(e); + } + } + } + + // Also poll for upstream tasks - if we get one, cancel the write and handle it. + upstream_task = rx.recv(), if !response_state.upstream_done() && serve_from_cache.is_on() && next_upstream_task.is_none() => { + if let Some(t) = upstream_task { + next_upstream_task = Some(t); + continue; + } else { + response_state.maybe_set_upstream_done(true); + } } } } diff --git a/pingora-proxy/src/proxy_h1.rs b/pingora-proxy/src/proxy_h1.rs index 9f04289c..e74309eb 100644 --- a/pingora-proxy/src/proxy_h1.rs +++ b/pingora-proxy/src/proxy_h1.rs @@ -267,6 +267,83 @@ where Ok(()) } + #[allow(clippy::too_many_arguments)] + async fn process_upstream_tasks( + &self, + session: &mut Session, + ctx: &mut SV::CTX, + initial_task: HttpTask, + rx: &mut mpsc::Receiver, + serve_from_cache: &mut ServeFromCache, + range_body_filter: &mut proxy_cache::range_filter::RangeBodyFilter, + response_state: &mut ResponseStateMachine, + ) -> Result> + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + if serve_from_cache.should_discard_upstream() { + // just drain, do we need to do anything else? 
+ return Ok(None); + } + + // Batch: pull as many tasks as we can from rx + let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + tasks.push(initial_task); + // tokio::task::unconstrained because now_or_never may yield None when the future is ready + while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { + debug!("upstream event now: {:?}", maybe_task); + if let Some(t) = maybe_task { + tasks.push(t); + } else { + break; // upstream closed + } + } + + /* run filters before sending to downstream */ + let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + for mut t in tasks { + if self.revalidate_or_stale(session, &mut t, ctx).await { + serve_from_cache.enable(); + response_state.enable_cached_response(); + // skip downstream filtering entirely as the 304 will not be sent + break; + } + #[cfg(feature = "upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } + #[cfg(feature = "upstream_modules")] + session.upstream_modules_filter_task(&mut t).await?; + session.upstream_compression.response_filter(&mut t); + let task = self + .h1_response_filter(session, t, ctx, serve_from_cache, range_body_filter, false) + .await?; + if serve_from_cache.is_miss_header() { + response_state.enable_cached_response(); + } + // check error and abort + // otherwise the error is surfaced via write_response_tasks() + if !serve_from_cache.should_send_to_downstream() { + if let HttpTask::Failed(e) = task { + return Err(e); + } + } + filtered_tasks.push(task); + } + + if !serve_from_cache.should_send_to_downstream() { + // TODO: need to derive response_done from filtered_tasks in case downstream failed already + return Ok(None); + } + + let response_done = session.write_response_tasks(filtered_tasks).await?; + + Ok(Some(response_done)) + } + // todo use this function to replace bidirection_1to2() // returns whether this server (downstream) 
session can be reused async fn proxy_handle_downstream( @@ -329,6 +406,8 @@ where let mut serve_from_cache = proxy_cache::ServeFromCache::new(); let mut range_body_filter = proxy_cache::range_filter::RangeBodyFilter::new(); + let mut next_upstream_task: Option = None; + /* duplex mode without caching * Read body from downstream while reading response from upstream * If response is done, only read body from downstream @@ -424,68 +503,56 @@ where // If tx is closed, the upstream has already finished its job. downstream_state.maybe_finished(tx.is_closed()); debug!("waiting for permit {send_permit:?}, upstream closed {}", tx.is_closed()); - /* No permit, wait on more capacity to avoid starving. + /* No permit, wait on more capacity to avoid starving. * Otherwise this select only blocks on rx, which might send no data * before the entire body is uploaded. * once more capacity arrives we just loop back */ }, - task = rx.recv(), if !response_state.upstream_done() => { - debug!("upstream event: {:?}", task); + // Handle buffered upstream task from previous iteration + task = async { next_upstream_task.take() }, if next_upstream_task.is_some() => { + debug!("buffered upstream event: {:?}", task); if let Some(t) = task { - if serve_from_cache.should_discard_upstream() { - // just drain, do we need to do anything else? 
- continue; - } - // pull as many tasks as we can - let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - tasks.push(t); - // tokio::task::unconstrained because now_or_never may yield None when the future is ready - while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { - debug!("upstream event now: {:?}", maybe_task); - if let Some(t) = maybe_task { - tasks.push(t); - } else { - break; // upstream closed - } - } - - /* run filters before sending to downstream */ - let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - for mut t in tasks { - if self.revalidate_or_stale(session, &mut t, ctx).await { - serve_from_cache.enable(); - response_state.enable_cached_response(); - // skip downstream filtering entirely as the 304 will not be sent - break; - } - session.upstream_compression.response_filter(&mut t); - let task = self.h1_response_filter(session, t, ctx, - &mut serve_from_cache, - &mut range_body_filter, false).await?; - if serve_from_cache.is_miss_header() { - response_state.enable_cached_response(); - } - // check error and abort - // otherwise the error is surfaced via write_response_tasks() - if !serve_from_cache.should_send_to_downstream() { - if let HttpTask::Failed(e) = task { - return Err(e); - } - } - filtered_tasks.push(task); - } - - if !serve_from_cache.should_send_to_downstream() { - // TODO: need to derive response_done from filtered_tasks in case downstream failed already + let Some(response_done) = self.process_upstream_tasks( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. 
serve_from_cache continue; - } + }; + response_state.maybe_set_upstream_done(response_done); + // unsuccessful upgrade response may force the request done + downstream_state.maybe_finished(session.is_body_done()); + } else { + debug!("empty upstream event"); + response_state.maybe_set_upstream_done(true); + } + }, - // set to downstream + task = rx.recv(), if !response_state.upstream_done() && next_upstream_task.is_none() => { + debug!("upstream event: {:?}", task); + if let Some(t) = task { let upgraded = session.was_upgraded(); - let response_done = session.write_response_tasks(filtered_tasks).await?; + let Some(response_done) = self.process_upstream_tasks( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache + continue; + }; if !upgraded && session.was_upgraded() && downstream_state.can_poll() { + // TODO: write can happen async now // just upgraded, the downstream state should be reset to continue to // poll body trace!("reset downstream state on upgrade"); @@ -502,35 +569,100 @@ where }, task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), - if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { + if !response_state.cached_done() + && !downstream_state.is_errored() + && serve_from_cache.is_on() + && !session.has_pending_downstream_tasks() => { // backpressure: don't queue if pending writes let task = self.h1_response_filter(session, task?, ctx, &mut serve_from_cache, &mut range_body_filter, true).await?; debug!("serve_from_cache task {task:?}"); - match session.write_response_tasks(vec![task]).await { - Ok(b) => response_state.maybe_set_cache_done(b), - Err(e) => if serve_from_cache.is_miss() { - // give up writing to downstream but wait for upstream cache write to finish - downstream_state.to_errored(); - response_state.maybe_set_cache_done(true); - 
warn!( - "Downstream Error ignored during caching: {}, {}", - e, - self.inner.request_summary(session, ctx) - ); - // This will not be treated as a final error, but we should signal to - // downstream session regardless - session.downstream_session.on_proxy_failure(e); - continue; - } else { - return Err(e); + if session.downstream_session.supports_proxy_task_api() { + session.send_downstream_proxy_task(task).await?; + } else { + match session.write_response_tasks(vec![task]).await { + Ok(b) => response_state.maybe_set_cache_done(b), + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream Error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); + continue; + } else { + return Err(e); + } + } + // A storage error can disable cache between cached_done + // being set and here; disable() drops the enabled_ctx so + // finish_hit_handler would panic without this guard. + if response_state.cached_done() && session.cache.enabled() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } } } - if response_state.cached_done() { - if let Err(e) = session.cache.finish_hit_handler().await { - warn!("Error during finish_hit_handler: {}", e); + } + + // Write queued downstream proxy tasks while also polling for upstream tasks. + // This allows cache writes to continue even when downstream is stalled. + // + // "Gate" branch: ready(()) resolves immediately, so the guard controls + // whether we enter. 
This is not a busy-loop because every path through + // the inner select either (a) drains all pending tasks via + // write_downstream_proxy_tasks (making the guard false), (b) stores an + // upstream task in next_upstream_task (making the guard false), or + // (c) blocks on real I/O inside the nested select. + _ = std::future::ready(()), if session.has_pending_downstream_tasks() && next_upstream_task.is_none() => { + tokio::select! { + // Try to write downstream proxy tasks (cancel-safe) + write_result = session.write_downstream_proxy_tasks() => { + match write_result { + Ok(end) => { + response_state.maybe_set_cache_done(end); + // See enabled() guard comment above. + if response_state.cached_done() && session.cache.enabled() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } + } + } + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream write error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); + } else { + return Err(e); + } + } + } + + // Also poll for upstream tasks - if we get one, cancel the write and handle it. + // Only poll if there is no buffered task already waiting to be processed. 
+ upstream_task = rx.recv(), if !response_state.upstream_done() && serve_from_cache.is_on() && next_upstream_task.is_none() => { + if let Some(t) = upstream_task { + // Store this upstream task to be processed next iteration + next_upstream_task = Some(t); + continue; + } else { + response_state.maybe_set_upstream_done(true); + } } } } diff --git a/pingora-proxy/src/proxy_h2.rs b/pingora-proxy/src/proxy_h2.rs index 0d633e4a..e5030819 100644 --- a/pingora-proxy/src/proxy_h2.rs +++ b/pingora-proxy/src/proxy_h2.rs @@ -89,7 +89,7 @@ where { let mut req = session.req_header().clone(); - if req.version != Version::HTTP_2 { + if req.version != Version::HTTP_2 || session.downstream_session.is_custom() { /* remove H1 specific headers */ // https://github.com/hyperium/h2/blob/d3b9f1e36aadc1a7a6804e2f8e86d3fe4a244b4f/src/proto/streams/send.rs#L72 req.remove_header(&http::header::TRANSFER_ENCODING); @@ -265,6 +265,89 @@ where (server_session_reuse, error) } + #[allow(clippy::too_many_arguments)] + async fn process_upstream_tasks_h2( + &self, + session: &mut Session, + ctx: &mut SV::CTX, + initial_task: HttpTask, + rx: &mut mpsc::Receiver, + serve_from_cache: &mut ServeFromCache, + range_body_filter: &mut proxy_cache::range_filter::RangeBodyFilter, + response_state: &mut ResponseStateMachine, + ) -> Result> + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + if serve_from_cache.should_discard_upstream() { + // just drain, do we need to do anything else? 
+ return Ok(None); + } + + // Batch: pull as many tasks as we can from rx + let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + tasks.push(initial_task); + // tokio::task::unconstrained because now_or_never may yield None when the future is ready + while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { + if let Some(t) = maybe_task { + tasks.push(t); + } else { + break; // upstream closed + } + } + + /* run filters before sending to downstream */ + let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + for mut t in tasks { + if self.revalidate_or_stale(session, &mut t, ctx).await { + serve_from_cache.enable(); + response_state.enable_cached_response(); + // skip downstream filtering entirely as the 304 will not be sent + break; + } + #[cfg(feature = "upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } + #[cfg(feature = "upstream_modules")] + session.upstream_modules_filter_task(&mut t).await?; + session.upstream_compression.response_filter(&mut t); + // check error and abort + // otherwise the error is surfaced via write_response_tasks() + if !serve_from_cache.should_send_to_downstream() { + if let HttpTask::Failed(e) = t { + return Err(e); + } + } + filtered_tasks.push( + self.h2_response_filter( + session, + t, + ctx, + serve_from_cache, + range_body_filter, + false, + ) + .await?, + ); + if serve_from_cache.is_miss_header() { + response_state.enable_cached_response(); + } + } + + if !serve_from_cache.should_send_to_downstream() { + // TODO: need to derive response_done from filtered_tasks in case downstream failed already + return Ok(None); + } + + let response_done = session.write_response_tasks(filtered_tasks).await?; + + Ok(Some(response_done)) + } + // returns whether server (downstream) session can be reused async fn bidirection_down_to_up( &self, @@ -322,6 +405,8 @@ where let mut serve_from_cache = 
ServeFromCache::new(); let mut range_body_filter = proxy_cache::range_filter::RangeBodyFilter::new(); + let mut next_upstream_task: Option = None; + /* duplex mode * see the Same function for h1 for more comments */ @@ -388,57 +473,47 @@ where }; }, - task = rx.recv(), if !response_state.upstream_done() => { + // Handle buffered upstream task from previous iteration + task = async { next_upstream_task.take() }, if next_upstream_task.is_some() => { + debug!("buffered upstream event: {:?}", task); if let Some(t) = task { - debug!("upstream event: {:?}", t); - if serve_from_cache.should_discard_upstream() { - // just drain, do we need to do anything else? - continue; - } - // pull as many tasks as we can - let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - tasks.push(t); - // tokio::task::unconstrained because now_or_never may yield None when the future is ready - while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { - if let Some(t) = maybe_task { - tasks.push(t); - } else { - break - } - } - - /* run filters before sending to downstream */ - let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - for mut t in tasks { - if self.revalidate_or_stale(session, &mut t, ctx).await { - serve_from_cache.enable(); - response_state.enable_cached_response(); - // skip downstream filtering entirely as the 304 will not be sent - break; - } - session.upstream_compression.response_filter(&mut t); - // check error and abort - // otherwise the error is surfaced via write_response_tasks() - if !serve_from_cache.should_send_to_downstream() { - if let HttpTask::Failed(e) = t { - return Err(e); - } - } - filtered_tasks.push( - self.h2_response_filter(session, t, ctx, - &mut serve_from_cache, - &mut range_body_filter, false).await?); - if serve_from_cache.is_miss_header() { - response_state.enable_cached_response(); - } - } - - if !serve_from_cache.should_send_to_downstream() { - // TODO: need to derive response_done from filtered_tasks in case 
downstream failed already + let Some(response_done) = self.process_upstream_tasks_h2( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache continue; + }; + if session.was_upgraded() { + return Error::e_explain(H2Error, "upgraded while proxying to h2 session"); } + response_state.maybe_set_upstream_done(response_done); + } else { + debug!("empty upstream event"); + response_state.maybe_set_upstream_done(true); + } + }, - let response_done = session.write_response_tasks(filtered_tasks).await?; + task = rx.recv(), if !response_state.upstream_done() && next_upstream_task.is_none() => { + debug!("upstream event: {:?}", task); + if let Some(t) = task { + let Some(response_done) = self.process_upstream_tasks_h2( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache + continue; + }; if session.was_upgraded() { // it is very weird if the downstream session decides to upgrade // since the client h2 session cannot, return an error on this case @@ -449,37 +524,99 @@ where debug!("empty upstream event"); response_state.maybe_set_upstream_done(true); } - } + }, task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), - if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { + if !response_state.cached_done() + && !downstream_state.is_errored() + && serve_from_cache.is_on() + && !session.has_pending_downstream_tasks() => { // backpressure: don't queue if pending writes + let task = self.h2_response_filter(session, task?, ctx, &mut serve_from_cache, &mut range_body_filter, true).await?; debug!("serve_from_cache task {task:?}"); - match session.write_response_tasks(vec![task]).await { - Ok(b) => response_state.maybe_set_cache_done(b), - Err(e) => if 
serve_from_cache.is_miss() { - // give up writing to downstream but wait for upstream cache write to finish - downstream_state.to_errored(); - response_state.maybe_set_cache_done(true); - warn!( - "Downstream Error ignored during caching: {}, {}", - e, - self.inner.request_summary(session, ctx) - ); - // This will not be treated as a final error, but we should signal to - // downstream session regardless - session.downstream_session.on_proxy_failure(e); - continue; - } else { - return Err(e); + if session.downstream_session.supports_proxy_task_api() { + session.send_downstream_proxy_task(task).await?; + } else { + match session.write_response_tasks(vec![task]).await { + Ok(b) => response_state.maybe_set_cache_done(b), + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream Error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); + continue; + } else { + return Err(e); + } + } + // A storage error can disable cache between cached_done + // being set and here; see the same guard in proxy_h1.rs. + if response_state.cached_done() && session.cache.enabled() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } } } - if response_state.cached_done() { - if let Err(e) = session.cache.finish_hit_handler().await { - warn!("Error during finish_hit_handler: {}", e); + } + + // Write queued downstream proxy tasks while also polling for upstream tasks. + // This allows cache writes to continue even when downstream is stalled. + // + // "Gate" branch: ready(()) resolves immediately, so the guard controls + // whether we enter. 
This is not a busy-loop because every path through + // the inner select either (a) drains all pending tasks via + // write_downstream_proxy_tasks (making the guard false), (b) stores an + // upstream task in next_upstream_task (making the guard false), or + // (c) blocks on real I/O inside the nested select. + _ = std::future::ready(()), if session.has_pending_downstream_tasks() && next_upstream_task.is_none() => { + tokio::select! { + // Try to write downstream proxy tasks (cancel-safe) + write_result = session.write_downstream_proxy_tasks() => { + match write_result { + Ok(end) => { + response_state.maybe_set_cache_done(end); + // The cache may have been disabled since cached_done was set. + // See enabled() guard comment above. + if response_state.cached_done() && session.cache.enabled() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } + } + } + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream write error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + session.downstream_session.on_proxy_failure(e); + } else { + return Err(e); + } + } + } + + // Also poll for upstream tasks - if we get one, cancel the write and handle it.
+ upstream_task = rx.recv(), if !response_state.upstream_done() && serve_from_cache.is_on() && next_upstream_task.is_none() => { + if let Some(t) = upstream_task { + next_upstream_task = Some(t); + continue; + } else { + response_state.maybe_set_upstream_done(true); + } } } } diff --git a/pingora-proxy/src/proxy_trait.rs b/pingora-proxy/src/proxy_trait.rs index f4193fca..2411092d 100644 --- a/pingora-proxy/src/proxy_trait.rs +++ b/pingora-proxy/src/proxy_trait.rs @@ -57,6 +57,23 @@ pub trait ProxyHttp { modules.add_module(ResponseCompressionBuilder::enable(0)); } + /// Set up upstream modules. + /// + /// In this phase, users can add [HttpModules] that will process upstream responses + /// **before** `upstream_compression`. This is the correct place to register modules + /// that need to observe the raw (pre-compression) upstream response body, such as + /// a dictionary store for shared dictionary compression. + /// + /// Upstream modules are ordered by [`HttpModuleBuilder::order()`]: higher values run + /// first. They are invoked on each upstream response task (header, body, trailers) + /// before `upstream_compression` processes the task. + /// + /// By default this method does nothing. + /// + /// This method requires the `upstream_modules` feature to be enabled. + #[cfg(feature = "upstream_modules")] + fn init_upstream_modules(&self, _modules: &mut HttpModules) {} + /// Handle the incoming request. /// /// In this phase, users can parse, validate, rate limit, perform access control and/or @@ -293,6 +310,39 @@ pub trait ProxyHttp { Ok(()) } + /// Adjust upstream modules before they process the response header. + /// + /// This filter is called when the upstream response header arrives, before upstream modules + /// (such as `upstream_compression`) run their response header filter. Use this to configure + /// module behavior based on the response, e.g. setting a dictionary for dictionary-based + /// content encoding. 
+ /// + /// This filter may be called more than once per request if the upstream sends informational + /// (1xx) response headers before the final response. Implementations can check + /// [`upstream_response.status.is_informational()`](http::StatusCode::is_informational) to + /// distinguish informational headers from the final response if needed. + /// + /// `end_of_stream` indicates whether the response header is also the end of the response + /// (e.g. for HEAD responses or 304s with no body). + /// + /// The response header is provided as an immutable reference. To modify the response header + /// itself, use [`Self::upstream_response_filter()`] instead. + /// + /// This filter requires the `upstream_modules` feature to be enabled. + #[cfg(feature = "upstream_modules")] + async fn adjust_upstream_modules( + &self, + _session: &mut Session, + _upstream_response: &ResponseHeader, + _end_of_stream: bool, + _ctx: &mut Self::CTX, + ) -> Result<()> + where + Self::CTX: Send + Sync, + { + Ok(()) + } + /// Modify the response header from the upstream /// /// The modification is before caching, so any change here will be stored in the cache if enabled. diff --git a/pingora-proxy/src/subrequest/pipe.rs b/pingora-proxy/src/subrequest/pipe.rs index 6dd4a57e..7845d4dc 100644 --- a/pingora-proxy/src/subrequest/pipe.rs +++ b/pingora-proxy/src/subrequest/pipe.rs @@ -42,16 +42,37 @@ pub enum InputBodyType { SaveBody(usize), } -/// Context struct as a result of subrequest piping. -#[derive(Clone)] +/// Outcome of [`pipe_subrequest`]. +#[derive(Debug, Default)] pub struct PipeSubrequestState { - /// The saved (captured) body from the main session. + /// Captured body from the main session. pub saved_body: Option, + /// Did the subrequest produce a response header? Checked before the task + /// filter runs, so a filtered-out header still counts. + pub header_received: bool, + /// The spawned subrequest task handle. Always set after spawn. 
Caller is + /// responsible for awaiting/inspecting state. + pub join_handle: Option>, + /// The receiving half of the pipe channel. When the coordinator exits + /// `pipe_subrequest` before the subrequest task finishes writing, this + /// receiver must be kept alive and drained alongside the join handle; + /// otherwise dropping it breaks the pipe and prevents the writer from + /// completing its cache-write lifecycle. + pub pipe_rx: Option>, } impl PipeSubrequestState { - fn new() -> PipeSubrequestState { - PipeSubrequestState { saved_body: None } + /// Creates a snapshot for error reporting, excluding the join handle. + /// Moves `pipe_rx` into the snapshot so the receiver stays alive through + /// the error path and is not dropped when `self` is cleaned up. + /// Used by [`map_pipe_err`] to capture state at the point of failure. + pub fn snapshot_for_error(&mut self) -> Self { + PipeSubrequestState { + saved_body: self.saved_body.clone(), + header_received: self.header_received, + join_handle: None, + pipe_rx: self.pipe_rx.take(), + } } } @@ -79,9 +100,9 @@ impl PipeSubrequestError { fn map_pipe_err>>( result: Result, from_subreq: bool, - state: &PipeSubrequestState, + state: &mut PipeSubrequestState, ) -> Result { - result.map_err(|e| PipeSubrequestError::new(e, from_subreq, state.clone())) + result.map_err(|e| PipeSubrequestError::new(e, from_subreq, state.snapshot_for_error())) } #[derive(Debug, Clone)] @@ -182,13 +203,13 @@ where }; let mut downstream_state = DownstreamStateMachine::new(no_body_input); - let mut state = PipeSubrequestState::new(); - state.saved_body = saved_body; + let mut state = PipeSubrequestState { + saved_body, + ..Default::default() + }; - // Have the subrequest remove all body-related headers if no body will be sent - // TODO: we could also await the join handle, but subrequest may be running logging phase - // also the full run() may also await cache fill if downstream fails - let _join_handle = tokio::spawn(async move { + // Remove 
body-related request headers when no body will be sent. + let join_handle = tokio::spawn(async move { if no_body_input { subrequest .session_mut() @@ -196,10 +217,14 @@ where .expect("PreparedSubrequest must be subrequest") .clear_request_body_headers(); } - subrequest.run().await + let _ = subrequest.run().await; }); + state.join_handle = Some(join_handle); let tx = subrequest_handle.tx; - let mut rx = subrequest_handle.rx; + // Move rx into state immediately so it survives all exit paths (early `?` + // returns, errors, and the normal success path). The select loop borrows it + // back via `state.pipe_rx.as_mut().expect(...)`. + state.pipe_rx = Some(subrequest_handle.rx); let mut wants_body = false; let mut wants_body_rx_err = false; @@ -216,20 +241,27 @@ where .or_err(InternalError, "try_reserve() body pipe for subrequest"); tokio::select! { - task = rx.recv(), if !response_state.upstream_done() => { + task = state.pipe_rx.as_mut().expect("pipe_rx always set after spawn").recv(), if !response_state.upstream_done() => { debug!("upstream event: {:?}", task); if let Some(t) = task { + // Did the subrequest get headers?
+ if matches!(&t, HttpTask::Header(..)) { + state.header_received = true; + } // pull as many tasks as we can const TASK_BUFFER_SIZE: usize = 4; let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - let task = map_pipe_err(task_filter(t), false, &state)?; + let task = map_pipe_err(task_filter(t), false, &mut state)?; if let Some(filtered) = task { tasks.push(filtered); } // tokio::task::unconstrained because now_or_never may yield None when the future is ready - while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { + while let Some(maybe_task) = tokio::task::unconstrained(state.pipe_rx.as_mut().expect("pipe_rx always set after spawn").recv()).now_or_never() { if let Some(t) = maybe_task { - let task = map_pipe_err(task_filter(t), false, &state)?; + if matches!(&t, HttpTask::Header(..)) { + state.header_received = true; + } + let task = map_pipe_err(task_filter(t), false, &mut state)?; if let Some(filtered) = task { tasks.push(filtered); } @@ -239,7 +271,7 @@ where } // FIXME: if one of these tasks is Failed(e), the session will return that // error; in this case, the error is actually from the subreq - let response_done = map_pipe_err(session.write_response_tasks(tasks).await, false, &state)?; + let response_done = map_pipe_err(session.write_response_tasks(tasks).await, false, &mut state)?; // NOTE: technically it is the downstream whose response state has finished here // we consider the subrequest's work done however @@ -248,9 +280,7 @@ where // (can only happen with a real session, TODO to allow with preset body) downstream_state.maybe_finished(!use_preset_body && session.is_body_done()); } else { - // quite possible that the subrequest may be finished, though the main session - // is not - we still must exit in this case - debug!("empty upstream event"); + debug!("upstream channel closed early"); response_state.maybe_set_upstream_done(true); } }, @@ -291,7 +321,7 @@ where // this is the first subrequest // send the body 
debug!("downstream event: main body for subrequest"); - let body = map_pipe_err(body.map_err(|e| e.into_down()), false, &state)?; + let body = map_pipe_err(body.map_err(|e| e.into_down()), false, &mut state)?; // If the request is websocket, `None` body means the request is closed. // Set the response to be done as well so that the request completes normally. @@ -307,7 +337,7 @@ where state.saved_body.as_mut(), send_permit.expect("checked is_ok()"), ) - .await, false, &state)?; + .await, false, &mut state)?; downstream_state.maybe_finished(request_done); @@ -328,7 +358,7 @@ where is_body_done, None, send_permit.expect("checked is_ok()"), - ), false, &state)?; + ), false, &mut state)?; downstream_state.maybe_finished(request_done); }, @@ -397,3 +427,64 @@ fn do_send_body_to_pipe( Ok(end_of_body) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::subrequest::Ctx as SubrequestCtx; + use crate::{Session, Subrequest, SubrequestSpawner}; + use async_trait::async_trait; + use pingora_core::protocols::http::ServerSession as HttpSession; + use std::sync::Arc; + + /// Drops session without producing output — channels close, rx returns None. 
+ struct NoopApp; + + #[async_trait] + impl Subrequest for NoopApp { + async fn process_subrequest( + self: Arc, + _session: Box, + _ctx: Box, + ) { + } + } + + async fn mock_session() -> Session { + let input = b"GET / HTTP/1.1\r\nHost: test\r\n\r\n"; + let mock_io = tokio_test::io::Builder::new().read(&input[..]).build(); + let mut session = Session::new_h1(Box::new(mock_io) as pingora_core::protocols::Stream); + session + .downstream_session + .read_request() + .await + .expect("mock request should parse"); + session + } + + #[tokio::test] + async fn no_header_received_when_subrequest_exits_silently() { + let mut session = mock_session().await; + + let spawner = SubrequestSpawner::new(Arc::new(NoopApp)); + let ctx = SubrequestCtx::builder().body_mode(BodyMode::NoBody).build(); + let (subrequest, handle) = spawner.create_subrequest(session.as_downstream(), ctx); + + let result = pipe_subrequest( + &mut session, + subrequest, + handle, + |task| Ok(Some(task)), + InputBodyType::Preset(InputBody::NoBody), + ) + .await; + + let state = + result.unwrap_or_else(|e| panic!("pipe should return Ok, not Err: {:?}", e.error)); + assert!( + !state.header_received, + "no header should have been received from the no-op subrequest" + ); + assert!(state.join_handle.is_some(), "task handle should be set"); + } +} diff --git a/pingora-proxy/tests/test_basic.rs b/pingora-proxy/tests/test_basic.rs index 77303fc3..cc48cb42 100644 --- a/pingora-proxy/tests/test_basic.rs +++ b/pingora-proxy/tests/test_basic.rs @@ -17,7 +17,8 @@ mod utils; use bytes::Bytes; use h2::client; use http::Request; -use hyper::{body::HttpBody, header::HeaderValue, Body, Client}; +use http_body_util::BodyExt; +use hyper_util::client::legacy::Client; #[cfg(unix)] use hyperlocal::{UnixClientExt, Uri}; use reqwest::{header, StatusCode}; @@ -161,21 +162,21 @@ async fn test_h2_to_h2() { async fn test_h2c_to_h2c() { init(); - let client = hyper::client::Client::builder() + let client = 
hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) .http2_only(true) - .build_http(); + .build_http::>(); - let mut req = hyper::Request::builder() + let mut req = http::Request::builder() .uri("http://127.0.0.1:6146") - .body(Body::empty()) + .body(http_body_util::Empty::::new()) .unwrap(); req.headers_mut() - .insert("x-h2", HeaderValue::from_bytes(b"true").unwrap()); + .insert("x-h2", http::HeaderValue::from_bytes(b"true").unwrap()); let res = client.request(req).await.unwrap(); assert_eq!(res.status(), reqwest::StatusCode::OK); assert_eq!(res.version(), reqwest::Version::HTTP_2); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body.as_ref(), b"Hello World!\n"); } @@ -183,21 +184,21 @@ async fn test_h2c_to_h2c() { async fn test_h1_on_h2c_port() { init(); - let client = hyper::client::Client::builder() + let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) .http2_only(false) - .build_http(); + .build_http::>(); - let mut req = hyper::Request::builder() + let mut req = http::Request::builder() .uri("http://127.0.0.1:6146") - .body(Body::empty()) + .body(http_body_util::Empty::::new()) .unwrap(); req.headers_mut() - .insert("x-h2", HeaderValue::from_bytes(b"true").unwrap()); + .insert("x-h2", http::HeaderValue::from_bytes(b"true").unwrap()); let res = client.request(req).await.unwrap(); assert_eq!(res.status(), reqwest::StatusCode::OK); assert_eq!(res.version(), reqwest::Version::HTTP_11); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body.as_ref(), b"Hello World!\n"); } @@ -303,7 +304,7 @@ async fn test_h2_head() { async fn test_simple_proxy_uds() { init(); let url = Uri::new("/tmp/pingora_proxy.sock", "/").into(); - let client = Client::unix(); + let client: Client> = Client::unix(); let res = 
client.get(url).await.unwrap(); @@ -324,7 +325,10 @@ async fn test_simple_proxy_uds() { assert_eq!(sockaddr.ip().to_string(), "127.0.0.2"); assert!(is_specified_port(sockaddr.port())); - let body = hyper::body::to_bytes(body).await.unwrap(); + let body = http_body_util::BodyExt::collect(body) + .await + .unwrap() + .to_bytes(); assert_eq!(body.as_ref(), b"Hello World!\n"); } diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs index b22a1ead..862009e5 100644 --- a/pingora-proxy/tests/test_upstream.rs +++ b/pingora-proxy/tests/test_upstream.rs @@ -17,9 +17,11 @@ mod utils; use utils::server_utils::init; use utils::websocket::{WS_ECHO, WS_ECHO_RAW}; +use bytes::Bytes; use futures::{SinkExt, StreamExt}; +use http::header::{HeaderName, HeaderValue}; +use http_body_util::BodyExt; use pingora_http::ResponseHeader; -use reqwest::header::{HeaderName, HeaderValue}; use reqwest::{StatusCode, Version}; use std::time::{Duration, Instant}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -70,6 +72,25 @@ async fn test_connection_die() { assert!(body.is_err()); } +// This test is ignored because it has a fundamental timing dependency. +// +// The nginx origin sends a 200 response and flushes it, then sleeps 1s +// and kills the connection (RST). The test expects the client to always +// receive the 200 before the connection dies. +// +// This fails under CI load because the 15MB request body takes longer +// than 1s to write. The proxy's select! loop is busy writing body chunks +// upstream and can't read the 200 response concurrently. When the 1s +// expires and nginx sends a TCP RST, the RST discards all buffered data +// (including the 200) per TCP semantics. The proxy then sees an upstream +// error and resets the client connection. +// +// The underlying issue is that TCP RST discards unread buffered data, +// so the 200 response is lost even though it was sent before the RST. 
+// Fixing this would require the proxy to read the response before or +// concurrently with the body write completing, which is a deeper +// architectural change. +#[ignore] #[tokio::test] async fn test_upload_connection_die() { init(); @@ -181,7 +202,7 @@ async fn test_ws_server_ends_conn() { ws_stream.close(None).await.unwrap(); let msg = ws_stream.next().await.unwrap().unwrap(); // assert echo - assert_eq!("test", msg.into_text().unwrap()); + assert_eq!(msg.into_text().unwrap(), "test"); let msg = ws_stream.next().await.unwrap().unwrap(); // assert graceful close assert!(matches!(msg, Message::Close(None))); @@ -363,22 +384,22 @@ async fn test_upgrade_body_after_101() { #[tokio::test] async fn test_download_timeout() { init(); - use hyper::body::HttpBody; use tokio::time::sleep; - let client = hyper::Client::new(); - let uri: hyper::Uri = "http://127.0.0.1:6147/download_large/".parse().unwrap(); - let req = hyper::Request::builder() + let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) + .build_http::>(); + let uri: http::Uri = "http://127.0.0.1:6147/download_large/".parse().unwrap(); + let req = http::Request::builder() .uri(uri) .header("x-write-timeout", "1") - .body(hyper::Body::empty()) + .body(http_body_util::Empty::::new()) .unwrap(); let mut res = client.request(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let mut err = false; sleep(Duration::from_secs(2)).await; - while let Some(chunk) = res.body_mut().data().await { + while let Some(chunk) = res.body_mut().frame().await { if chunk.is_err() { err = true; } @@ -389,28 +410,27 @@ async fn test_download_timeout() { #[tokio::test] async fn test_download_timeout_min_rate() { init(); - use hyper::body::HttpBody; use tokio::time::sleep; - let client = hyper::Client::new(); - let uri: hyper::Uri = "http://127.0.0.1:6147/download/".parse().unwrap(); - let req = hyper::Request::builder() + let client = 
hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) + .build_http::>(); + let uri: http::Uri = "http://127.0.0.1:6147/download/".parse().unwrap(); + let req = http::Request::builder() .uri(uri) .header("x-write-timeout", "1") .header("x-min-rate", "10000") - .body(hyper::Body::empty()) + .body(http_body_util::Empty::::new()) .unwrap(); let mut res = client.request(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let mut err = false; sleep(Duration::from_secs(2)).await; - while let Some(chunk) = res.body_mut().data().await { + while let Some(chunk) = res.body_mut().frame().await { if chunk.is_err() { err = true; } } - // no error as write timeout is overridden by min rate assert!(!err); } @@ -482,8 +502,9 @@ async fn test_h2_upstream_no_end_stream_read_timeout() { } }); - tokio::time::sleep(Duration::from_millis(50)).await; - + // The listener was bound before the spawn (line 426), so the kernel + // is already accepting connections into the backlog. No readiness + // wait needed. let client = reqwest::Client::new(); let url = "http://127.0.0.1:6147/test"; @@ -532,6 +553,74 @@ async fn test_h2_upstream_no_end_stream_read_timeout() { } } +/// Mock origin that sends 100 Continue then a final response for any request. +/// Returns the port the server is listening on. 
+async fn mock_100_continue_server() -> u16 { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + tokio::spawn(async move { + if let Ok((mut stream, _addr)) = listener.accept().await { + // Read the request (just drain it) + let mut buf = [0u8; 4096]; + let _ = stream.read(&mut buf).await.unwrap(); + + // Send 100 Continue + stream + .write_all(b"HTTP/1.1 100 Continue\r\n\r\n") + .await + .unwrap(); + // Small delay so the client reads the 100 separately + tokio::time::sleep(Duration::from_millis(100)).await; + + // Send final 200 OK with Content-Length (but no body for HEAD) + stream + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 42\r\n\r\n") + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + } + }); + + port +} + +#[tokio::test] +async fn test_head_with_100_continue() { + init(); + + let port = mock_100_continue_server().await; + + let mut stream = TcpStream::connect("127.0.0.1:6147").await.unwrap(); + stream + .write_all( + format!("HEAD / HTTP/1.1\r\nHost: localhost\r\nx-port: {port}\r\n\r\n").as_bytes(), + ) + .await + .unwrap(); + + // Read through any 1xx until we get the final (non-1xx) response + let result = timeout(Duration::from_secs(5), async { + let mut resp; + let mut body; + loop { + (resp, body) = read_response_header(&mut stream).await; + if resp.status.as_u16() >= 200 { + return (resp, body); + } + } + }) + .await + .expect("should not time out waiting for final response"); + + let (resp, body) = result; + assert_eq!(resp.status.as_u16(), 200); + // HEAD responses have no body even with Content-Length + assert!(body.is_empty(), "HEAD response should have no body"); +} + mod test_cache { use super::*; use std::str::FromStr; @@ -1311,35 +1400,37 @@ mod test_cache { // set up a one-off mock server // (warp / hyper don't have custom 1xx sending capabilities yet) - async fn 
mock_1xx_server(port: u16, cc_header: &str) { - use tokio::io::AsyncWriteExt; - - let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) - .await - .unwrap(); - if let Ok((mut stream, _addr)) = listener.accept().await { - stream.write_all(b"HTTP/1.1 103 Early Hints\r\nLink: ; rel=preconnect\r\n\r\n").await.unwrap(); - // wait a bit so that the client can read - sleep(Duration::from_millis(100)).await; - stream.write_all(format!("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nCache-Control: {}\r\n\r\nhello", cc_header).as_bytes()).await.unwrap(); - sleep(Duration::from_millis(100)).await; - } + // One-shot mock server that sends a 103 Early Hints then a final 200. + // Binds to port 0 (OS-assigned) and returns the actual port via a + // oneshot channel once the listener is ready. + fn spawn_mock_1xx_server(cc_header: &'static str) -> tokio::sync::oneshot::Receiver { + let (tx, rx) = tokio::sync::oneshot::channel(); + tokio::spawn(async move { + use tokio::io::AsyncWriteExt; + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let _ = tx.send(port); // signal: port is bound + if let Ok((mut stream, _addr)) = listener.accept().await { + stream.write_all(b"HTTP/1.1 103 Early Hints\r\nLink: ; rel=preconnect\r\n\r\n").await.unwrap(); + sleep(Duration::from_millis(100)).await; + stream.write_all(format!("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nCache-Control: {}\r\n\r\nhello", cc_header).as_bytes()).await.unwrap(); + sleep(Duration::from_millis(100)).await; + } + }); + rx } init(); let url = "http://127.0.0.1:6148/unique/test_1xx_caching"; - tokio::spawn(async { - mock_1xx_server(6151, "max-age=5").await; - }); - // wait for server to start - sleep(Duration::from_millis(100)).await; + let port = spawn_mock_1xx_server("max-age=5").await.unwrap(); let client = reqwest::Client::new(); let res = client .get(url) - .header("x-port", "6151") + .header("x-port", port.to_string()) 
.send() .await .unwrap(); @@ -1348,9 +1439,10 @@ mod test_cache { assert_eq!(headers["x-cache-status"], "miss"); assert_eq!(res.text().await.unwrap(), "hello"); + // Second request to the same URL should be a cache hit (no server needed) let res = client .get(url) - .header("x-port", "6151") + .header("x-port", port.to_string()) .send() .await .unwrap(); @@ -1362,15 +1454,11 @@ mod test_cache { // 1xx shouldn't interfere with bypass let url = "http://127.0.0.1:6148/unique/test_1xx_bypass"; - tokio::spawn(async { - mock_1xx_server(6152, "private, no-store").await; - }); - // wait for server to start - sleep(Duration::from_millis(100)).await; + let port = spawn_mock_1xx_server("private, no-store").await.unwrap(); let res = client .get(url) - .header("x-port", "6152") + .header("x-port", port.to_string()) .send() .await .unwrap(); @@ -1380,16 +1468,11 @@ mod test_cache { assert_eq!(res.text().await.unwrap(), "hello"); // restart the one-off server - still uncacheable - sleep(Duration::from_millis(100)).await; - tokio::spawn(async { - mock_1xx_server(6152, "private, no-store").await; - }); - // wait for server to start - sleep(Duration::from_millis(100)).await; + let port = spawn_mock_1xx_server("private, no-store").await.unwrap(); let res = client .get(url) - .header("x-port", "6152") + .header("x-port", port.to_string()) .send() .await .unwrap(); @@ -2850,6 +2933,154 @@ mod test_cache { assert_eq!(res.text().await.unwrap(), "hello world"); } + #[tokio::test] + #[ignore = "flaky in CI due to timing/resource contention"] + async fn test_caching_when_downstream_stalls() { + use std::net::ToSocketAddrs; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpStream; + + init(); + let url = "http://127.0.0.1:6148/unique/test_caching_when_downstream_stalls/download/"; + + // Connection 1: read 10KiB then stall, holding the cache lock while + // the proxy populates cache from upstream. 
+ let slow_task = tokio::spawn(async move { + let addr = "127.0.0.1:6148".to_socket_addrs().unwrap().next().unwrap(); + let mut stream = TcpStream::connect(&addr).await.unwrap(); + + let request = concat!( + "GET /unique/test_caching_when_downstream_stalls/download/ HTTP/1.1\r\n", + "Host: 127.0.0.1:6148\r\n", + "x-lock: true\r\n", + "x-set-cache-control: public, max-age=60\r\n", + "\r\n", + ); + stream.write_all(request.as_bytes()).await.unwrap(); + + let mut buf = [0; 10 * 1024]; + let mut b = &mut buf[..]; + while !b.is_empty() { + let n = stream.read(b).await.unwrap(); + b = &mut b[n..] + } + + // Hold the stalled connection open long enough + sleep(Duration::from_secs(10)).await; + }); + + // Give connection 1 time to acquire the cache lock + sleep(Duration::from_secs(1)).await; + + // Connection 2: should get a cache hit once the proxy finishes + // populating cache from upstream (independent of stall). + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(url) + .header("x-lock", "true") + .header("x-set-cache-control", "public, max-age=60") + .timeout(Duration::from_secs(8)) + .send() + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_eq!(headers["x-cache-status"], "hit"); + + // If the cache was populated fast enough (before connection 2 arrived), + // there is no lock contention and x-cache-lock-time-ms is absent. + // If there was contention, the wait should be short. 
+ if let Some(lock_ms) = headers.get("x-cache-lock-time-ms") { + let ms: u64 = lock_ms.to_str().unwrap().parse().unwrap(); + assert!( + ms < 2000, + "lock wait {ms}ms should be well under the 2s timeout" + ); + } + + assert_eq!( + res.text().await.unwrap(), + String::from("A").repeat(4 * 1024 * 1024) + ); + + let elapsed = start.elapsed(); + assert!( + elapsed < Duration::from_secs(5), + "second request took {elapsed:?}, should be fast" + ); + + // Don't wait for the slow connection + slow_task.abort(); + } + + // Same as test_caching_when_downstream_stalls but the proxy connects + // to the origin over H2 (via the x-h2 header). + // + #[tokio::test] + #[ignore = "flaky in CI due to timing/resource contention"] + async fn test_caching_h2_upstream_when_downstream_stalls() { + use std::net::ToSocketAddrs; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpStream; + + init(); + let url = "http://127.0.0.1:6148/unique/test_caching_h2_upstream_when_downstream_stalls/download/"; + + let slow_task = tokio::spawn(async move { + let addr = "127.0.0.1:6148".to_socket_addrs().unwrap().next().unwrap(); + let mut stream = TcpStream::connect(&addr).await.unwrap(); + + let request = concat!( + "GET /unique/test_caching_h2_upstream_when_downstream_stalls/download/ HTTP/1.1\r\n", + "Host: 127.0.0.1:6148\r\n", + "x-h2: true\r\n", + "x-lock: true\r\n", + "x-set-cache-control: public, max-age=60\r\n", + "\r\n", + ); + stream.write_all(request.as_bytes()).await.unwrap(); + + let mut buf = [0; 10 * 1024]; + let mut b = &mut buf[..]; + while !b.is_empty() { + let n = stream.read(b).await.unwrap(); + b = &mut b[n..] 
+ } + + sleep(Duration::from_secs(10)).await; + }); + + sleep(Duration::from_secs(1)).await; + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(url) + .header("x-h2", "true") + .header("x-lock", "true") + .header("x-set-cache-control", "public, max-age=60") + .timeout(Duration::from_secs(8)) + .send() + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_eq!(headers["x-cache-status"], "hit"); + assert_eq!( + res.text().await.unwrap(), + String::from("A").repeat(4 * 1024 * 1024) + ); + + let elapsed = start.elapsed(); + assert!( + elapsed < Duration::from_secs(5), + "second request took {elapsed:?}, should be fast (upstream-speed-bound)" + ); + + slow_task.abort(); + } + async fn send_vary_req_with_headers_with_dups( url: &str, vary_field: &str, @@ -3473,4 +3704,145 @@ mod test_cache { assert_eq!(headers["x-cache-status"], "hit"); assert_eq!(res.text().await.unwrap(), "hello world"); } + + // Ignored until H2 downstream gets the proxy task API + // (write_response_tasks blocks on flow control today). + // multi_thread needed for h2 connection driver tasks. 
+ #[tokio::test(flavor = "multi_thread")] + #[ignore] + async fn test_cache_h2_downstream_stalls() { + init(); + + use h2::client; + use http::Request; + use tokio::net::TcpStream; + use tokio::time::{timeout, Duration}; + + // Step 1: Connection 1 - Open h2 connection to h2c cache proxy (port 6154) and STALL + let tcp1 = TcpStream::connect("127.0.0.1:6154").await.unwrap(); + let (mut h2_client1, h2_conn1) = client::handshake(tcp1).await.unwrap(); + + tokio::spawn(async move { + if let Err(e) = h2_conn1.await { + eprintln!("H2 connection 1 error: {:?}", e); + } + }); + + // Request the cached resource on connection 1 + let request1 = Request::builder() + .uri("http://127.0.0.1/unique/test_h2_stall/download/") + .body(()) + .unwrap(); + + let (response1, _) = h2_client1.send_request(request1, true).unwrap(); + let response1 = response1.await.unwrap(); + assert_eq!(response1.status(), 200); + assert_eq!(response1.headers()["x-cache-status"], "miss"); + + let mut body1 = response1.into_body(); + + // Read first chunk but don't release flow control to stall connection 1 + let first_chunk = body1.data().await.unwrap().unwrap(); + assert!(!first_chunk.is_empty()); + + // Connection 2 - While conn 1 is stalled, try to get the same cached resource + let tcp2 = TcpStream::connect("127.0.0.1:6154").await.unwrap(); + let (mut h2_client2, h2_conn2) = client::handshake(tcp2).await.unwrap(); + + tokio::spawn(async move { + if let Err(e) = h2_conn2.await { + eprintln!("H2 connection 2 error: {:?}", e); + } + }); + + let request2 = Request::builder() + .uri("http://127.0.0.1/unique/test_h2_stall/download/") + .body(()) + .unwrap(); + + let (response2, _) = h2_client2.send_request(request2, true).unwrap(); + + // Try to read, proxy should not be blocked + let response2 = match timeout(Duration::from_secs(5), response2).await { + Ok(Ok(resp)) => resp, + Ok(Err(e)) => panic!("Connection 2 failed: {:?}", e), + Err(_) => panic!("Connection 2 timed out - proxy blocked without proxy task 
API!"), + }; + + assert_eq!(response2.status(), 200); + assert_eq!(response2.headers()["x-cache-status"], "hit"); + + // Read full response from connection 2 + let mut body2 = response2.into_body(); + let mut received2 = Vec::new(); + while let Some(Ok(chunk)) = timeout(Duration::from_secs(5), body2.data()) + .await + .expect("should not time out waiting for data") + { + let len = chunk.len(); + received2.extend_from_slice(&chunk); + body2.flow_control().release_capacity(len).unwrap(); + } + + assert_eq!( + received2.len(), + 4 * 1024 * 1024, + "Connection 2 should receive full cached response" + ); + + // Clean up: unstall connection 1 + body1 + .flow_control() + .release_capacity(first_chunk.len()) + .unwrap(); + } + + // Test cache population from H2 upstream origin with H1 downstream. + #[tokio::test] + async fn test_cache_upstream_h2_downstream_h1() { + init(); + + let test_url = "http://127.0.0.1:6148/unique/test_h2_upstream/download/"; + + // Step 1: Populate cache from H2 origin (cache miss) + let client = reqwest::Client::new(); + let res = client + .get(test_url) + .header("x-h2", "true") + .header("x-lock", "true") + .header("x-set-cache-control", "public, max-age=60") + .send() + .await + .unwrap(); + + assert_eq!(res.status(), 200); + assert_eq!(res.headers()["x-cache-status"], "miss"); + assert_eq!(res.headers()["origin-http2"], "h2c"); + + let body = res.bytes().await.unwrap(); + assert_eq!( + body.len(), + 4 * 1024 * 1024, + "Should receive full 4MB response" + ); + + // Step 2: Request again and verify cache hit + let res = client + .get(test_url) + .header("x-h2", "true") + .header("x-set-cache-control", "public, max-age=60") + .send() + .await + .unwrap(); + + assert_eq!(res.status(), 200); + assert_eq!(res.headers()["x-cache-status"], "hit"); + + let body = res.bytes().await.unwrap(); + assert_eq!( + body.len(), + 4 * 1024 * 1024, + "Should receive full 4MB from cache" + ); + } } diff --git a/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf 
b/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf index f19c974c..969695eb 100644 --- a/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf +++ b/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf @@ -311,7 +311,6 @@ http { location /download/ { content_by_lua_block { - ngx.req.read_body() local body = string.rep("A", 4194304) ngx.header["Content-Length"] = #body ngx.print(body) diff --git a/pingora-proxy/tests/utils/mock_origin.rs b/pingora-proxy/tests/utils/mock_origin.rs index 74840e19..fa5a327f 100644 --- a/pingora-proxy/tests/utils/mock_origin.rs +++ b/pingora-proxy/tests/utils/mock_origin.rs @@ -59,7 +59,22 @@ fn init() -> bool { .output() .unwrap(); }); - // wait until the server is up - thread::sleep(time::Duration::from_secs(2)); - true + // Wait until openresty is accepting connections, then give it a moment + // to finish worker initialization. + let deadline = time::Instant::now() + time::Duration::from_secs(10); + while time::Instant::now() < deadline { + if std::net::TcpStream::connect_timeout( + &"127.0.0.1:8000".parse().unwrap(), + time::Duration::from_millis(100), + ) + .is_ok() + { + // Port is listening; allow a brief window for workers to finish + // initializing before tests start sending real requests. 
+ thread::sleep(time::Duration::from_millis(500)); + return true; + } + thread::sleep(time::Duration::from_millis(50)); + } + panic!("mock origin (openresty) failed to start within 10s"); } diff --git a/pingora-proxy/tests/utils/server_utils.rs b/pingora-proxy/tests/utils/server_utils.rs index 5a418934..0dccb6dd 100644 --- a/pingora-proxy/tests/utils/server_utils.rs +++ b/pingora-proxy/tests/utils/server_utils.rs @@ -253,8 +253,19 @@ impl ProxyHttp for ExampleProxyHttp { session: &mut Session, _ctx: &mut Self::CTX, ) -> Result<()> { - let req = session.req_header(); - let downstream_compression = req.headers.get("x-downstream-compression").is_some(); + let proxy_tasks_enabled = session + .req_header() + .headers + .get("x-proxy-tasks-enabled") + .is_some(); + if proxy_tasks_enabled { + session.downstream_session.set_proxy_tasks_enabled(true); + } + let downstream_compression = session + .req_header() + .headers + .get("x-downstream-compression") + .is_some(); if downstream_compression { session .downstream_modules_ctx @@ -658,6 +669,12 @@ impl ProxyHttp for ExampleProxyCache { upstream_response.remove_header(&CONTENT_LENGTH); upstream_response.remove_header(&TRANSFER_ENCODING); } + // Allow tests to inject Cache-Control into the upstream response + if let Some(cc) = session.req_header().headers.get("x-set-cache-control") { + upstream_response + .insert_header(http::header::CACHE_CONTROL, cc) + .unwrap(); + } Ok(()) } @@ -823,6 +840,15 @@ fn test_main() { pingora_proxy::http_proxy_service(&my_server.configuration, ExampleProxyCache {}); proxy_service_cache.add_tcp("0.0.0.0:6148"); + // H2C-enabled cache proxy on port 6154 + let mut proxy_service_cache_h2c = + pingora_proxy::http_proxy_service(&my_server.configuration, ExampleProxyCache {}); + let cache_h2c_logic = proxy_service_cache_h2c.app_logic_mut().unwrap(); + let mut cache_h2c_options = HttpServerOptions::default(); + cache_h2c_options.h2c = true; + cache_h2c_logic.server_options = Some(cache_h2c_options); + 
proxy_service_cache_h2c.add_tcp("0.0.0.0:6154"); + #[cfg(feature = "any_tls")] { let cert_path = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR")); @@ -839,6 +865,7 @@ fn test_main() { Box::new(proxy_service_http), Box::new(proxy_service_http_connect), Box::new(proxy_service_cache), + Box::new(proxy_service_cache_h2c), ]; if let Some(proxy_service_https) = proxy_service_https_opt { @@ -873,16 +900,28 @@ pub struct PskTlsServer { #[cfg(feature = "s2n")] impl PskTlsServer { pub fn start() -> Self { - let server_handle = thread::spawn(|| { + use std::sync::mpsc; + use std::time::Duration; + + // Use a channel to wait for the server to bind its port. + // A TCP probe can't be used here because the TLS acceptor would + // try to handshake the probe connection, fail, and panic. + let (tx, rx) = mpsc::channel(); + let server_handle = thread::spawn(move || { let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(Self::run_server()); + rt.block_on(Self::run_server(tx)); }); + + // Wait up to 10s for the server to signal it has bound the port. 
+ rx.recv_timeout(Duration::from_secs(10)) + .expect("PSK TLS server failed to start within 10s"); + PskTlsServer { handle: server_handle, } } - async fn run_server() { + async fn run_server(ready_tx: std::sync::mpsc::Sender<()>) { use pingora_core::{protocols::tls::S2NConnectionBuilder, tls::TlsAcceptor}; use pingora_core::{ protocols::tls::{Psk, PskConfig, PskType}, @@ -899,6 +938,8 @@ impl PskTlsServer { let addr: std::net::SocketAddr = "127.0.0.1:6151".parse().unwrap(); let listener = TcpListener::bind(addr).await.unwrap(); + let _ = ready_tx.send(()); // signal: port is bound + let mut config_builder = Config::builder(); unsafe { config_builder.disable_x509_verification(); @@ -915,12 +956,20 @@ impl PskTlsServer { let acceptor = TlsAcceptor::new(connection_builder); loop { - use tokio::{io::AsyncWriteExt, net::tcp}; + use tokio::io::AsyncWriteExt; let (tcp_stream, _) = listener.accept().await.unwrap(); - let mut stream = acceptor.clone().accept(tcp_stream).await.unwrap(); + // Don't panic on handshake failure — a stale connection or probe + // shouldn't take down the server for subsequent real connections. 
+ let mut stream = match acceptor.clone().accept(tcp_stream).await { + Ok(s) => s, + Err(e) => { + log::warn!("PSK TLS server: handshake failed: {e}"); + continue; + } + }; let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"; - stream.write_all(response).await.unwrap(); - stream.shutdown().await; + let _ = stream.write_all(response).await; + let _ = stream.shutdown().await; } } } diff --git a/pingora-rustls/Cargo.toml b/pingora-rustls/Cargo.toml index efa377bf..51cd00ff 100644 --- a/pingora-rustls/Cargo.toml +++ b/pingora-rustls/Cargo.toml @@ -18,7 +18,7 @@ path = "src/lib.rs" log = "0.4.21" pingora-error = { version = "0.8.0", path = "../pingora-error"} ring = "0.17.12" -rustls = "0.23.12" +rustls = { version = "0.23.12", features = ["ring"] } rustls-native-certs = "0.7.1" rustls-pemfile = "2.1.2" rustls-pki-types = "1.7.0" diff --git a/pingora-rustls/src/lib.rs b/pingora-rustls/src/lib.rs index 097a8da5..deb0c88b 100644 --- a/pingora-rustls/src/lib.rs +++ b/pingora-rustls/src/lib.rs @@ -28,9 +28,19 @@ use pingora_error::{Error, ErrorType, OrErr, Result}; pub use rustls::server::danger::{ClientCertVerified, ClientCertVerifier}; pub use rustls::server::{ClientCertVerifierBuilder, WebPkiClientVerifier}; pub use rustls::{ - client::WebPkiServerVerifier, version, CertificateError, ClientConfig, DigitallySignedStruct, - Error as RusTlsError, KeyLogFile, RootCertStore, ServerConfig, SignatureScheme, Stream, + client::WebPkiServerVerifier, crypto::CryptoProvider, version, CertificateError, ClientConfig, + DigitallySignedStruct, Error as RusTlsError, KeyLogFile, RootCertStore, ServerConfig, + SignatureScheme, Stream, }; + +/// Install the default `ring` CryptoProvider for rustls. +/// +/// rustls 0.23+ requires an explicit provider. This function installs `ring` +/// as the process-level default. Safe to call multiple times — subsequent +/// calls are no-ops. 
+pub fn install_default_crypto_provider() { + let _ = CryptoProvider::install_default(rustls::crypto::ring::default_provider()); +} pub use rustls_native_certs::load_native_certs; use rustls_pemfile::Item; pub use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; diff --git a/pingora/Cargo.toml b/pingora/Cargo.toml index dd890bdb..8e30130a 100644 --- a/pingora/Cargo.toml +++ b/pingora/Cargo.toml @@ -29,6 +29,7 @@ pingora-load-balancing = { version = "0.8.0", path = "../pingora-load-balancing" pingora-proxy = { version = "0.8.0", path = "../pingora-proxy", optional = true, default-features = false } pingora-cache = { version = "0.8.0", path = "../pingora-cache", optional = true, default-features = false } + # Only used for documenting features, but doesn't work in any other dependency # group :( document-features = { version = "0.2.10", optional = true } @@ -37,18 +38,21 @@ document-features = { version = "0.2.10", optional = true } clap = { version = "4.5", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread", "signal"] } env_logger = "0.11" -reqwest = { version = "0.11", features = ["rustls"], default-features = false } -hyper = "0.14" +reqwest = { version = "0.12", features = ["rustls-tls", "http2"], default-features = false } +hyper = { version = "1", features = ["client", "http1", "http2"] } +hyper-util = { version = "0.1", features = ["client-legacy", "http1", "http2"] } +http-body-util = "0.1" async-trait = { workspace = true } http = { workspace = true } log = { workspace = true } +pingora-prometheus = { version = "0.8.0", path = "../pingora-prometheus" } prometheus = "0.14" once_cell = { workspace = true } bytes = { workspace = true } regex = "1" [target.'cfg(unix)'.dev-dependencies] -hyperlocal = "0.8" +hyperlocal = "0.9" jemallocator = "0.5" [features] @@ -126,6 +130,13 @@ time = [] ## Enable sentry for error notifications sentry = ["pingora-core/sentry"] +## Enable upstream modules: the 
`adjust_upstream_modules` callback, the +## `upstream_modules_ctx` on Session, and `init_upstream_modules` on ProxyHttp. +## +## Allows registering custom upstream modules that process response tasks +## before `upstream_compression`, and configuring them. +upstream_modules = ["pingora-proxy?/upstream_modules"] + ## Enable pre-TLS connection filtering connection_filter = [ "pingora-core/connection_filter", @@ -146,4 +157,4 @@ document-features = [ "sentry", "connection_filter" ] -prometheus = ["pingora-core/prometheus"] +trace = ["pingora-cache?/trace", "pingora-proxy?/trace"] diff --git a/pingora/examples/graceful_upgrade.rs b/pingora/examples/graceful_upgrade.rs new file mode 100644 index 00000000..5a64ff7f --- /dev/null +++ b/pingora/examples/graceful_upgrade.rs @@ -0,0 +1,186 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! # Graceful Upgrade Example +//! +//! Demonstrates the `daemon_wait_for_ready` feature, which coordinates graceful process upgrades +//! by ensuring the new process is fully bootstrapped before the old one begins shutting down. +//! +//! ## Background +//! +//! In a standard daemonized pingora service, the parent process exits immediately after the +//! daemon fork. During a graceful upgrade, the process manager sends SIGQUIT to the old process +//! as soon as the new process's parent exits — potentially before the new process has finished +//! 
initializing its backends, consistent hash rings, or other state. This can cause a brief +//! window of 502s. +//! +//! With `daemon_wait_for_ready = true`, the parent instead waits for the daemon to send SIGUSR1 +//! before exiting. The process manager only proceeds to stop the old process once the new one +//! signals that it is ready to serve traffic. +//! +//! ## Service startup order +//! +//! This example sets up the following dependency chain: +//! +//! ```text +//! BackendDiscoveryService HashRingService +//! \ / +//! \ / +//! BootstrapService (socket transfer + SIGUSR1 to parent) +//! ``` +//! +//! The bootstrap service — which handles transferring listening sockets from the old process and +//! sending SIGUSR1 to the parent to signal readiness — only runs after both slow initialization +//! services have completed. This ensures the parent never exits until the new process is truly +//! ready to serve traffic. +//! +//! ## Usage +//! +//! ```bash +//! # Run interactively (no daemonization) +//! cargo run --example graceful_upgrade -p pingora +//! +//! # Run as a daemon +//! cargo run --example graceful_upgrade -p pingora -- -d +//! +//! # Graceful upgrade of a running daemon instance +//! cargo run --example graceful_upgrade -p pingora -- -d -u +//! ``` + +use async_trait::async_trait; +use bytes::Bytes; +use clap::Parser; +use http::{Response, StatusCode}; +use log::info; +use std::num::NonZeroU64; +use std::time::Duration; +use tokio::time::sleep; + +use pingora::apps::http_app::ServeHttp; +use pingora::prelude::Opt; +use pingora::protocols::http::ServerSession; +use pingora::server::configuration::ServerConf; +use pingora::server::{Server, ShutdownWatch}; +use pingora::services::background::{background_service, BackgroundService}; +use pingora::services::listening::Service as ListeningService; + +/// Simulates slow backend discovery — e.g. resolving upstream endpoints from a service registry. 
+pub struct BackendDiscoveryService; + +#[async_trait] +impl BackgroundService for BackendDiscoveryService { + async fn start(&self, _shutdown: ShutdownWatch) { + info!("BackendDiscoveryService: discovering backends..."); + sleep(Duration::from_secs(2)).await; + info!("BackendDiscoveryService: backends ready"); + } +} + +/// Simulates slow consistent hash ring construction. Runs in parallel with +/// `BackendDiscoveryService`; bootstrap waits for both to complete. +pub struct HashRingService; + +#[async_trait] +impl BackgroundService for HashRingService { + async fn start(&self, _shutdown: ShutdownWatch) { + info!("HashRingService: building consistent hash ring..."); + sleep(Duration::from_secs(3)).await; + info!("HashRingService: hash ring ready"); + } +} + +/// A minimal HTTP service that responds to every request with 200 OK. +/// +/// Accepts an optional `sleep` query parameter specifying how many seconds to wait before +/// responding (e.g. `GET /?sleep=20`). This makes in-flight requests easy to observe during a +/// graceful upgrade: a request with a long sleep that arrives just before the upgrade begins will +/// still be running when the new process starts up, demonstrating that the old process keeps +/// serving until all connections are drained. 
+pub struct HelloApp; + +#[async_trait] +impl ServeHttp for HelloApp { + async fn response(&self, http_stream: &mut ServerSession) -> Response<Vec<u8>> { + let delay_secs = http_stream + .req_header() + .uri + .query() + .and_then(|q| { + q.split('&').find_map(|pair| { + let (key, val) = pair.split_once('=')?; + if key == "sleep" { + val.parse::<u64>().ok() + } else { + None + } + }) + }) + .unwrap_or(0); + + if delay_secs > 0 { + sleep(Duration::from_secs(delay_secs)).await; + } + + let body = Bytes::from("hello from graceful_upgrade example\n"); + Response::builder() + .status(StatusCode::OK) + .header(http::header::CONTENT_TYPE, "text/plain") + .header(http::header::CONTENT_LENGTH, body.len()) + .body(body.to_vec()) + .unwrap() + } +} + +fn main() { + env_logger::init(); + + let opt = Some(Opt::parse()); + + // Build a ServerConf with daemon_wait_for_ready enabled. + // + // When the server is started with -d (daemon mode), the parent process waits for SIGUSR1 + // before exiting. The daemon sends SIGUSR1 only after the bootstrap service completes — + // which in this example means after both slow services have signaled readiness. + let conf = ServerConf { + daemon: true, + daemon_wait_for_ready: true, + daemon_ready_timeout_seconds: NonZeroU64::new(60), + ..ServerConf::default() + }; + + let mut server = Server::new_with_opt_and_conf(opt, conf); + + // Add the slow initialization services and retain their handles so bootstrap can depend + // on them. Both run in parallel; the slowest (HashRingService at 3s) sets the pace. + let backend_handle = server.add_service(background_service( + "backend_discovery", + BackendDiscoveryService, + )); + let hash_ring_handle = server.add_service(background_service("hash_ring", HashRingService)); + + // bootstrap_as_a_service() registers the bootstrap service (socket transfer from the old + // process + SIGUSR1 to the parent) and returns its ServiceHandle.
Declaring the slow + // services as dependencies ensures bootstrap only runs once both are ready. + let bootstrap_handle = server.bootstrap_as_a_service(); + bootstrap_handle.add_dependencies([&backend_handle, &hash_ring_handle]); + + let mut http_service = ListeningService::new("hello_http".to_string(), HelloApp); + http_service.add_tcp("0.0.0.0:8000"); + + server + .add_service(http_service) + .add_dependency(backend_handle); + + server.run_forever(); +} diff --git a/pingora/examples/server.rs b/pingora/examples/server.rs index 1e299140..37a246cd 100644 --- a/pingora/examples/server.rs +++ b/pingora/examples/server.rs @@ -20,8 +20,6 @@ use pingora::protocols::TcpKeepalive; use pingora::server::configuration::Opt; use pingora::server::{Server, ShutdownWatch}; use pingora::services::background::{background_service, BackgroundService}; -#[cfg(feature = "prometheus")] -use pingora::services::listening::Service as ListeningService; use pingora::services::ServiceWithDependents; use async_trait::async_trait; @@ -187,9 +185,7 @@ pub fn main() { &key_path, ); - #[cfg(feature = "prometheus")] - let mut prometheus_service_http = ListeningService::prometheus_http_service(); - #[cfg(feature = "prometheus")] + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("127.0.0.1:6150"); let background_service = background_service("example", ExampleBackgroundService {}); @@ -199,7 +195,6 @@ pub fn main() { Box::new(echo_service_http), Box::new(proxy_service), Box::new(proxy_service_ssl), - #[cfg(feature = "prometheus")] Box::new(prometheus_service_http), Box::new(background_service), ]; diff --git a/pingora/tests/pingora_conf.yaml b/pingora/tests/pingora_conf.yaml deleted file mode 100644 index c21ae15a..00000000 --- a/pingora/tests/pingora_conf.yaml +++ /dev/null @@ -1,5 +0,0 @@ ---- -version: 1 -client_bind_to_ipv4: - - 127.0.0.2 -ca_file: tests/keys/server.crt \ No newline at end of file