From ac946f4c45804ea65e523fcba845e3b6d8a56857 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 12 Apr 2026 21:00:15 +0800 Subject: [PATCH 01/26] feat(streaming): implement checkpoint coordination with catalog-backed crash recovery Introduce a full Chandy-Lamport distributed snapshot mechanism with catalog-persisted safe epochs for exactly-once crash recovery: - Add LSM-Tree state engine with MVCC, Bloom filters, watermark GC, dedicated I/O thread pool with panic isolation (state/ module) - Add JobMasterEvent protocol for checkpoint ACK/Decline signaling - Integrate MemoryController and IoPool into JobManager lifecycle - Implement checkpoint coordinator: periodic barrier injection, ACK collection, and CatalogManager.commit_job_checkpoint() callback - Extend StreamingTableDefinition proto with checkpoint_interval_ms and latest_checkpoint_epoch for durable recovery metadata - Support user-defined checkpoint interval via SQL WITH clause - Restore streaming jobs from catalog with precise epoch recovery Made-with: Cursor --- Cargo.lock | 1 + Cargo.toml | 3 +- Makefile | 86 ++- protocol/proto/storage.proto | 8 + src/coordinator/execution/executor.rs | 22 +- src/coordinator/plan/logical_plan_visitor.rs | 18 + src/coordinator/plan/streaming_table_plan.rs | 3 + src/runtime/streaming/job/job_manager.rs | 228 ++++++- src/runtime/streaming/job/mod.rs | 2 +- src/runtime/streaming/mod.rs | 1 + src/runtime/streaming/protocol/control.rs | 13 + src/runtime/streaming/protocol/mod.rs | 2 + src/runtime/streaming/state/error.rs | 37 ++ src/runtime/streaming/state/io_manager.rs | 144 +++++ src/runtime/streaming/state/metrics.rs | 18 + src/runtime/streaming/state/mod.rs | 25 + src/runtime/streaming/state/operator_state.rs | 586 ++++++++++++++++++ src/server/initializer.rs | 8 +- src/storage/stream_catalog/manager.rs | 71 ++- 19 files changed, 1231 insertions(+), 45 deletions(-) create mode 100644 src/runtime/streaming/state/error.rs create mode 100644 
src/runtime/streaming/state/io_manager.rs create mode 100644 src/runtime/streaming/state/metrics.rs create mode 100644 src/runtime/streaming/state/mod.rs create mode 100644 src/runtime/streaming/state/operator_state.rs diff --git a/Cargo.lock b/Cargo.lock index e174c43f..33f0cbf3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2215,6 +2215,7 @@ dependencies = [ "lru", "num_cpus", "parking_lot", + "parquet", "petgraph 0.7.1", "proctitle", "prost", diff --git a/Cargo.toml b/Cargo.toml index 87d4ea03..fb380ff1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "sync", "tim serde = { version = "1.0", features = ["derive"] } serde_yaml = "0.9" serde_json = "1.0" -uuid = { version = "1.0", features = ["v4"] } +uuid = { version = "1.0", features = ["v4", "v7"] } log = "0.4" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } @@ -51,6 +51,7 @@ arrow = { version = "55", default-features = false } arrow-array = "55" arrow-ipc = "55" arrow-schema = { version = "55", features = ["serde"] } +parquet = "55" futures = "0.3" serde_json_path = "0.7" xxhash-rust = { version = "0.8", features = ["xxh3"] } diff --git a/Makefile b/Makefile index 4daf185b..e914b376 100644 --- a/Makefile +++ b/Makefile @@ -13,12 +13,50 @@ APP_NAME := function-stream VERSION := $(shell grep '^version' Cargo.toml | head -1 | awk -F '"' '{print $$2}') -ARCH := $(shell uname -m) -OS := $(shell uname -s | tr '[:upper:]' '[:lower:]') DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ") +# 1. Auto-detect system environment & normalize architecture +RAW_ARCH := $(shell uname -m) +# Fix macOS M-series returning arm64 while Rust expects aarch64 +ifeq ($(RAW_ARCH), arm64) + ARCH := aarch64 +else ifeq ($(RAW_ARCH), amd64) + ARCH := x86_64 +else + ARCH := $(RAW_ARCH) +endif + +OS := $(shell uname -s | tr '[:upper:]' '[:lower:]') +OS_NAME := $(shell uname -s) + +# 2. 
Configure RUSTFLAGS and target triple per platform DIST_ROOT := dist -TARGET_DIR := target/release +ifeq ($(OS_NAME), Linux) + # Linux: static-link musl for a truly self-contained, zero-dependency binary + TRIPLE := $(ARCH)-unknown-linux-musl + STATIC_FLAGS := -C target-feature=+crt-static +else ifeq ($(OS_NAME), Darwin) + # macOS: strip symbols but keep dynamic linking (Apple system restriction) + TRIPLE := $(ARCH)-apple-darwin + STATIC_FLAGS := +else ifneq (,$(findstring MINGW,$(OS_NAME))$(findstring MSYS,$(OS_NAME))) + # Windows (Git Bash / MSYS2): static-link MSVC runtime + TRIPLE := $(ARCH)-pc-windows-msvc + STATIC_FLAGS := -C target-feature=+crt-static +else + # Fallback + TRIPLE := $(ARCH)-unknown-linux-gnu + STATIC_FLAGS := +endif + +# 3. Aggressive optimization flags +# opt-level=z : size-oriented, minimize binary footprint +# strip=symbols: remove debug symbol table at link time +# Note: panic=abort is intentionally omitted to preserve stack unwinding +# for better fault tolerance in the streaming runtime +OPTIMIZE_FLAGS := -C opt-level=z -C strip=symbols $(STATIC_FLAGS) + +TARGET_DIR := target/$(TRIPLE)/release PYTHON_ROOT := python WASM_SOURCE := $(PYTHON_ROOT)/functionstream-runtime/target/functionstream-python-runtime.wasm @@ -42,7 +80,7 @@ C_0 := \033[0m log = @printf "$(C_B)[-]$(C_0) %-15s %s\n" "$(1)" "$(2)" success = @printf "$(C_G)[✔]$(C_0) %s\n" "$(1)" -.PHONY: all help build build-lite dist dist-lite clean test env env-clean go-sdk-env go-sdk-build go-sdk-clean docker docker-run docker-push .check-env .build-wasm +.PHONY: all help build build-lite dist dist-lite clean test env env-clean go-sdk-env go-sdk-build go-sdk-clean docker docker-run docker-push .check-env .ensure-target .build-wasm all: build @@ -65,18 +103,42 @@ help: @echo "" @echo " Version: $(VERSION) | Arch: $(ARCH) | OS: $(OS)" -build: .check-env .build-wasm - $(call log,BUILD,Rust Full Features) - @cargo build --release --features python --quiet +# 4. 
Auto-install missing Rust target toolchain +.ensure-target: + @rustup target list --installed | grep -q "$(TRIPLE)" || \ + (printf "$(C_Y)[!] Auto-installing target toolchain for $(OS_NAME): $(TRIPLE)$(C_0)\n" && \ + rustup target add $(TRIPLE)) + +# 5. Build targets (depend on .ensure-target for automatic toolchain setup) +build: .check-env .ensure-target .build-wasm + $(call log,BUILD,Rust Full [$(OS_NAME) / $(TRIPLE)]) + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + --features python \ + --quiet $(call log,BUILD,CLI) - @cargo build --release -p function-stream-cli --quiet + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + -p function-stream-cli \ + --quiet $(call success,Target: $(TARGET_DIR)/$(APP_NAME) $(TARGET_DIR)/cli) -build-lite: .check-env - $(call log,BUILD,Rust Lite No Python) - @cargo build --release --no-default-features --features incremental-cache --quiet +build-lite: .check-env .ensure-target + $(call log,BUILD,Rust Lite [$(OS_NAME) / $(TRIPLE)]) + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + --no-default-features \ + --features incremental-cache \ + --quiet $(call log,BUILD,CLI for dist) - @cargo build --release -p function-stream-cli --quiet + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + -p function-stream-cli \ + --quiet $(call success,Target: $(TARGET_DIR)/$(APP_NAME) $(TARGET_DIR)/cli) .build-wasm: diff --git a/protocol/proto/storage.proto b/protocol/proto/storage.proto index d7caf7bc..20e14862 100644 --- a/protocol/proto/storage.proto +++ b/protocol/proto/storage.proto @@ -52,6 +52,14 @@ message StreamingTableDefinition { // Stored as opaque bytes to avoid coupling storage schema with runtime API protos. bytes fs_program_bytes = 3; string comment = 4; + + // User-specified checkpoint interval from WITH clause (e.g. 'checkpoint.interval' = '5000'). + // 0 or unset means use system default. 
+ uint64 checkpoint_interval_ms = 5; + + // Last globally-committed checkpoint epoch. + // Updated by JobManager after all operators ACK. Used for crash recovery. + uint64 latest_checkpoint_epoch = 6; } // ============================================================================= diff --git a/src/coordinator/execution/executor.rs b/src/coordinator/execution/executor.rs index 0000d0cf..399ab775 100644 --- a/src/coordinator/execution/executor.rs +++ b/src/coordinator/execution/executor.rs @@ -322,25 +322,33 @@ impl PlanVisitor for Executor { let job_manager: Arc = Arc::clone(&self.job_manager); let job_id = plan.name.clone(); - let job_id = tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(job_manager.submit_job(job_id, fs_program.clone())) - }) - .map_err(|e| ExecuteError::Internal(format!("Failed to submit streaming job: {e}")))?; + + let custom_interval: Option = plan + .with_options + .as_ref() + .and_then(|opts| opts.get("checkpoint.interval")) + .and_then(|v| v.parse().ok()); self.catalog_manager .persist_streaming_job( &plan.name, &fs_program, plan.comment.as_deref().unwrap_or(""), + custom_interval.unwrap_or(0), ) .map_err(|e| { ExecuteError::Internal(format!( - "Streaming job '{}' submitted but persistence failed: {e}", - plan.name + "Streaming job persistence failed: {e}", )) })?; + let job_id = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on( + job_manager.submit_job(job_id, fs_program, custom_interval, None), + ) + }) + .map_err(|e| ExecuteError::Internal(format!("Failed to submit streaming job: {e}")))?; + info!( job_id = %job_id, table = %plan.name, diff --git a/src/coordinator/plan/logical_plan_visitor.rs b/src/coordinator/plan/logical_plan_visitor.rs index 6adc6420..d49d0314 100644 --- a/src/coordinator/plan/logical_plan_visitor.rs +++ b/src/coordinator/plan/logical_plan_visitor.rs @@ -168,10 +168,28 @@ impl LogicalPlanVisitor { let validated_program = 
self.validate_graph_topology(&final_logical_plan)?; + let streaming_with_options: Option> = + if with_options.is_empty() { + None + } else { + let map: std::collections::HashMap = with_options + .iter() + .filter_map(|opt| match opt { + SqlOption::KeyValue { key, value } => Some(( + key.value.clone(), + value.to_string().trim_matches('\'').to_string(), + )), + _ => None, + }) + .collect(); + if map.is_empty() { None } else { Some(map) } + }; + Ok(StreamingTable { name: sink_table_name, comment: comment.clone(), program: validated_program, + with_options: streaming_with_options, }) } diff --git a/src/coordinator/plan/streaming_table_plan.rs b/src/coordinator/plan/streaming_table_plan.rs index 512ec266..e155ba91 100644 --- a/src/coordinator/plan/streaming_table_plan.rs +++ b/src/coordinator/plan/streaming_table_plan.rs @@ -10,6 +10,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use super::{PlanNode, PlanVisitor, PlanVisitorContext, PlanVisitorResult}; use crate::sql::logical_node::logical::LogicalProgram; @@ -19,6 +21,7 @@ pub struct StreamingTable { pub name: String, pub comment: Option, pub program: LogicalProgram, + pub with_options: Option>, } impl PlanNode for StreamingTable { diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index 3082dc56..011a912e 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -10,13 +10,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::collections::HashMap; -use std::sync::{Arc, OnceLock, RwLock}; +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex, OnceLock, RwLock}; +use std::time::Duration; use anyhow::{Context, Result, anyhow, bail, ensure}; use tokio::sync::mpsc; +use tokio::task::JoinHandle as TokioJoinHandle; use tokio_stream::wrappers::ReceiverStream; -use tracing::{error, info, warn}; +use tracing::{error, info, warn, debug}; use protocol::function_stream_graph::{ChainedOperator, FsProgram}; @@ -31,7 +34,10 @@ use crate::runtime::streaming::job::models::{ }; use crate::runtime::streaming::memory::MemoryPool; use crate::runtime::streaming::network::endpoint::{BoxedEventStream, PhysicalSender}; -use crate::runtime::streaming::protocol::control::{ControlCommand, StopMode}; +use crate::runtime::streaming::protocol::control::{ControlCommand, StopMode, JobMasterEvent}; +use crate::runtime::streaming::protocol::event::CheckpointBarrier; +use crate::runtime::streaming::state::{IoManager, IoPool, MemoryController, NoopMetricsCollector}; +use crate::storage::stream_catalog::CatalogManager; #[derive(Debug, Clone)] pub struct StreamingJobSummary { @@ -57,12 +63,39 @@ pub struct StreamingJobDetail { pub program: FsProgram, } +#[derive(Debug, Clone)] +pub struct StateConfig { + pub max_background_spills: usize, + pub max_background_compactions: usize, + pub soft_limit_ratio: f64, + pub checkpoint_interval_ms: u64, +} + +impl Default for StateConfig { + fn default() -> Self { + Self { + max_background_spills: 4, + max_background_compactions: 2, + soft_limit_ratio: 0.7, + checkpoint_interval_ms: 10_000, + } + } +} + static GLOBAL_JOB_MANAGER: OnceLock> = OnceLock::new(); pub struct JobManager { active_jobs: Arc>>, operator_factory: Arc, memory_pool: Arc, + + #[allow(dead_code)] + memory_controller: Arc, + #[allow(dead_code)] + io_manager_client: IoManager, + io_pool: Mutex>, + state_base_dir: PathBuf, + state_config: StateConfig, } struct 
PreparedChain { @@ -85,17 +118,42 @@ impl PipelineRunner { } impl JobManager { - pub fn new(operator_factory: Arc, max_memory_bytes: usize) -> Self { - Self { + pub fn new( + operator_factory: Arc, + max_memory_bytes: usize, + state_base_dir: impl AsRef, + state_config: StateConfig, + ) -> Result { + let soft_limit_bytes = (max_memory_bytes as f64 * state_config.soft_limit_ratio) as usize; + let memory_controller = MemoryController::new(soft_limit_bytes, max_memory_bytes); + + let metrics = Arc::new(NoopMetricsCollector); + let (io_pool, io_manager_client) = IoPool::try_new( + state_config.max_background_spills, + state_config.max_background_compactions, + metrics, + ).context("Failed to initialize state engine I/O pool")?; + + Ok(Self { active_jobs: Arc::new(RwLock::new(HashMap::new())), operator_factory, memory_pool: MemoryPool::new(max_memory_bytes), - } + memory_controller, + io_manager_client, + io_pool: Mutex::new(Some(io_pool)), + state_base_dir: state_base_dir.as_ref().to_path_buf(), + state_config, + }) } - pub fn init(factory: Arc, memory_bytes: usize) -> Result<()> { + pub fn init( + factory: Arc, + memory_bytes: usize, + state_base_dir: PathBuf, + state_config: StateConfig, + ) -> Result<()> { GLOBAL_JOB_MANAGER - .set(Arc::new(Self::new(factory, memory_bytes))) + .set(Arc::new(Self::new(factory, memory_bytes, state_base_dir, state_config)?)) .map_err(|_| anyhow!("JobManager singleton already initialized")) } @@ -106,19 +164,44 @@ impl JobManager { .ok_or_else(|| anyhow!("JobManager not initialized. 
Call init() first.")) } - pub async fn submit_job(&self, job_id: String, program: FsProgram) -> Result { + pub fn shutdown(&self) { + if let Some(pool) = self.io_pool.lock().unwrap().take() { + pool.shutdown(); + } + } + + pub async fn submit_job( + &self, + job_id: String, + program: FsProgram, + custom_checkpoint_interval_ms: Option, + recovery_epoch: Option, + ) -> Result { let mut edge_manager = EdgeManager::build(&program.nodes, &program.edges); let mut pipelines = HashMap::with_capacity(program.nodes.len()); + let mut source_control_txs = Vec::new(); + let mut expected_pipeline_ids = HashSet::new(); + + let job_state_dir = self.state_base_dir.join(&job_id); + std::fs::create_dir_all(&job_state_dir).context("Failed to create job state dir")?; + + let (job_master_tx, job_master_rx) = mpsc::channel(256); + + let safe_epoch = recovery_epoch.unwrap_or(0); + for node in &program.nodes { let pipeline_id = node.node_index as u32; - let pipeline = self + let (pipeline, is_source) = self .build_and_spawn_pipeline( job_id.clone(), pipeline_id, &node.operators, &mut edge_manager, + &job_state_dir, + job_master_tx.clone(), + safe_epoch, ) .with_context(|| { format!( @@ -127,9 +210,25 @@ impl JobManager { ) })?; + if is_source { + source_control_txs.push(pipeline.control_tx.clone()); + } + expected_pipeline_ids.insert(pipeline_id); pipelines.insert(pipeline_id, pipeline); } + let interval_ms = custom_checkpoint_interval_ms + .unwrap_or(self.state_config.checkpoint_interval_ms); + + self.spawn_checkpoint_coordinator( + job_id.clone(), + source_control_txs, + job_master_rx, + expected_pipeline_ids, + interval_ms, + safe_epoch + 1, + ); + let graph = PhysicalExecutionGraph { job_id: job_id.clone(), program, @@ -143,7 +242,7 @@ impl JobManager { .map_err(|e| anyhow!("Active jobs lock poisoned: {}", e))?; jobs_guard.insert(job_id.clone(), graph); - info!(job_id = %job_id, "Job submitted successfully."); + info!(job_id = %job_id, interval_ms, recovery_epoch = safe_epoch, "Job 
submitted successfully."); Ok(job_id) } @@ -326,7 +425,10 @@ impl JobManager { pipeline_id: u32, operators: &[ChainedOperator], edge_manager: &mut EdgeManager, - ) -> Result { + _job_state_dir: &Path, + _job_master_tx: mpsc::Sender, + _recovery_epoch: u64, + ) -> Result<(PhysicalPipeline, bool)> { let (raw_inboxes, raw_outboxes) = edge_manager.take_endpoints(pipeline_id).with_context(|| { format!( @@ -352,6 +454,8 @@ impl JobManager { ) })?; + let is_source = chain.source.is_some(); + ensure!( chain.source.is_some() || !physical_inboxes.is_empty(), "Topology Error: Pipeline '{}' contains no source and has no upstream inputs (Dead end).", @@ -392,12 +496,13 @@ impl JobManager { .spawn_worker_thread(job_id, pipeline_id, runner, Arc::clone(&status)) .with_context(|| format!("Failed to spawn OS thread for pipeline {}", pipeline_id))?; - Ok(PhysicalPipeline { + let pipeline = PhysicalPipeline { pipeline_id, handle: Some(handle), status, control_tx, - }) + }; + Ok((pipeline, is_source)) } fn build_operator_chain(&self, operator_configs: &[ChainedOperator]) -> Result { @@ -509,4 +614,97 @@ impl JobManager { warn!(job_id = %job_id, pipeline_id = pipeline_id, "Pipeline failure detected. 
Job degraded."); } } + + // ======================================================================== + // Chandy-Lamport distributed snapshot barrier coordinator + // ======================================================================== + + fn spawn_checkpoint_coordinator( + &self, + job_id: String, + source_control_txs: Vec>, + mut job_master_rx: mpsc::Receiver, + expected_pipeline_ids: HashSet, + interval_ms: u64, + start_epoch: u64, + ) -> TokioJoinHandle<()> { + tokio::spawn(async move { + if interval_ms == 0 { + info!(job_id = %job_id, "Checkpoint disabled for this job"); + return; + } + + let mut interval = tokio::time::interval(Duration::from_millis(interval_ms)); + interval.tick().await; + + let mut current_epoch: u64 = start_epoch; + let mut pending_checkpoints: HashMap> = HashMap::new(); + + loop { + tokio::select! { + _ = interval.tick() => { + info!(job_id = %job_id, epoch = current_epoch, "Triggering global Checkpoint Barrier."); + pending_checkpoints.insert(current_epoch, expected_pipeline_ids.clone()); + + let barrier = CheckpointBarrier { + epoch: current_epoch as u32, + min_epoch: 0, + timestamp: std::time::SystemTime::now(), + then_stop: false, + }; + + for tx in &source_control_txs { + let cmd = ControlCommand::trigger_checkpoint(barrier); + if tx.send(cmd).await.is_err() { + debug!(job_id = %job_id, "Source disconnected. Shutting down coordinator."); + return; + } + } + current_epoch += 1; + } + + Some(event) = job_master_rx.recv() => { + match event { + JobMasterEvent::CheckpointAck { pipeline_id, epoch } => { + if let Some(pending_set) = pending_checkpoints.get_mut(&epoch) { + pending_set.remove(&pipeline_id); + + if pending_set.is_empty() { + info!( + job_id = %job_id, epoch = epoch, + "Checkpoint Epoch is GLOBALLY COMPLETED!" 
+ ); + + if let Some(catalog) = CatalogManager::try_global() { + if let Err(e) = catalog.commit_job_checkpoint(&job_id, epoch) { + error!( + job_id = %job_id, epoch = epoch, + error = %e, + "Failed to commit checkpoint metadata to Catalog" + ); + } + } else { + warn!( + job_id = %job_id, epoch = epoch, + "CatalogManager not available, checkpoint not persisted globally" + ); + } + + pending_checkpoints.remove(&epoch); + } + } + } + JobMasterEvent::CheckpointDecline { pipeline_id, epoch, reason } => { + error!( + job_id = %job_id, epoch = epoch, pipeline_id = pipeline_id, + reason = %reason, "Checkpoint FAILED!" + ); + pending_checkpoints.remove(&epoch); + } + } + } + } + } + }) + } } diff --git a/src/runtime/streaming/job/mod.rs b/src/runtime/streaming/job/mod.rs index 02e0343c..59d5c61f 100644 --- a/src/runtime/streaming/job/mod.rs +++ b/src/runtime/streaming/job/mod.rs @@ -14,4 +14,4 @@ pub mod edge_manager; pub mod job_manager; pub mod models; -pub use job_manager::{JobManager, StreamingJobSummary}; +pub use job_manager::{JobManager, StateConfig, StreamingJobSummary}; diff --git a/src/runtime/streaming/mod.rs b/src/runtime/streaming/mod.rs index 7e0ba57a..0e4e6758 100644 --- a/src/runtime/streaming/mod.rs +++ b/src/runtime/streaming/mod.rs @@ -23,5 +23,6 @@ pub mod memory; pub mod network; pub mod operators; pub mod protocol; +pub mod state; pub use protocol::StreamOutput; diff --git a/src/runtime/streaming/protocol/control.rs b/src/runtime/streaming/protocol/control.rs index 3b23cb09..e87ccd3b 100644 --- a/src/runtime/streaming/protocol/control.rs +++ b/src/runtime/streaming/protocol/control.rs @@ -79,3 +79,16 @@ pub enum StopMode { pub fn control_channel(capacity: usize) -> (Sender, Receiver) { mpsc::channel(capacity) } + +#[derive(Debug, Clone)] +pub enum JobMasterEvent { + CheckpointAck { + pipeline_id: u32, + epoch: u64, + }, + CheckpointDecline { + pipeline_id: u32, + epoch: u64, + reason: String, + }, +} diff --git a/src/runtime/streaming/protocol/mod.rs 
b/src/runtime/streaming/protocol/mod.rs index e91e8d8c..28fd85a4 100644 --- a/src/runtime/streaming/protocol/mod.rs +++ b/src/runtime/streaming/protocol/mod.rs @@ -13,4 +13,6 @@ pub mod control; pub mod event; +#[allow(unused_imports)] +pub use control::{ControlCommand, JobMasterEvent, StopMode}; pub use event::{CheckpointBarrier, StreamOutput, Watermark}; diff --git a/src/runtime/streaming/state/error.rs b/src/runtime/streaming/state/error.rs new file mode 100644 index 00000000..e04a022e --- /dev/null +++ b/src/runtime/streaming/state/error.rs @@ -0,0 +1,37 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. + +use thiserror::Error; +use crossbeam_channel::TrySendError; + +#[derive(Error, Debug)] +pub enum StateEngineError { + #[error("I/O error during state persistence: {0}")] + IoError(#[from] std::io::Error), + + #[error("Parquet serialization/deserialization failed: {0}")] + ParquetError(#[from] parquet::errors::ParquetError), + + #[error("Arrow computation failed: {0}")] + ArrowError(#[from] arrow::error::ArrowError), + + #[error("Memory hard limit exceeded and spill channel is full")] + MemoryBackpressureTimeout, + + #[error("Background I/O pool has been shut down or disconnected")] + IoPoolDisconnected, + + #[error("State metadata corrupted: {0}")] + Corruption(String), +} + +pub type Result = std::result::Result; + +impl From> for StateEngineError { + fn from(err: TrySendError) -> Self { + match err { + TrySendError::Full(_) => StateEngineError::MemoryBackpressureTimeout, + TrySendError::Disconnected(_) => StateEngineError::IoPoolDisconnected, + } + } +} diff --git a/src/runtime/streaming/state/io_manager.rs b/src/runtime/streaming/state/io_manager.rs new file mode 100644 index 00000000..aa85385b --- /dev/null +++ b/src/runtime/streaming/state/io_manager.rs @@ -0,0 +1,144 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this 
file except in compliance with the License. + +#[allow(unused_imports)] +use super::error::StateEngineError; +use super::metrics::StateMetricsCollector; +use super::operator_state::{MemTable, OperatorStateStore, TombstoneMap}; +use crossbeam_channel::{bounded, Receiver, Sender, TrySendError}; +use std::panic::{catch_unwind, AssertUnwindSafe}; +use std::sync::Arc; +use std::thread::{self, JoinHandle}; +use std::time::Instant; + +pub struct SpillJob { + pub store: Arc, + pub epoch: u64, + pub data: MemTable, + pub tombstone_snapshot: TombstoneMap, +} + +pub enum CompactJob { + Minor { store: Arc }, + Major { store: Arc }, +} + +pub struct IoPool { + spill_tx: Option>, + compact_tx: Option>, + worker_handles: Vec>, +} + +impl IoPool { + pub fn try_new( + spill_threads: usize, + compact_threads: usize, + metrics: Arc, + ) -> std::io::Result<(Self, IoManager)> { + let (spill_tx, spill_rx) = bounded::(1024); + let (compact_tx, compact_rx) = bounded::(256); + let mut worker_handles = Vec::with_capacity(spill_threads + compact_threads); + + for i in 0..spill_threads.max(1) { + let rx = spill_rx.clone(); + let m = metrics.clone(); + let handle = thread::Builder::new() + .name(format!("fs-spill-worker-{i}")) + .spawn(move || spill_worker_loop(rx, m))?; + worker_handles.push(handle); + } + + for i in 0..compact_threads.max(1) { + let rx = compact_rx.clone(); + let m = metrics.clone(); + let handle = thread::Builder::new() + .name(format!("fs-compact-worker-{i}")) + .spawn(move || compact_worker_loop(rx, m))?; + worker_handles.push(handle); + } + + let manager = IoManager { + spill_tx: spill_tx.clone(), + compact_tx: compact_tx.clone(), + }; + + Ok(( + Self { spill_tx: Some(spill_tx), compact_tx: Some(compact_tx), worker_handles }, + manager, + )) + } + + pub fn shutdown(mut self) { + tracing::info!("Initiating graceful shutdown for IoPool..."); + self.spill_tx.take(); + self.compact_tx.take(); + for handle in self.worker_handles.drain(..) 
{ + if let Err(e) = handle.join() { + tracing::error!("I/O Worker thread panicked during shutdown: {:?}", e); + } + } + tracing::info!("IoPool graceful shutdown completed."); + } +} + +#[derive(Clone)] +pub struct IoManager { + spill_tx: Sender, + compact_tx: Sender, +} + +impl IoManager { + pub fn try_send_spill(&self, job: SpillJob) -> Result<(), TrySendError> { + self.spill_tx.try_send(job) + } + pub fn try_send_compact(&self, job: CompactJob) -> Result<(), TrySendError> { + self.compact_tx.try_send(job) + } + pub fn pending_spills(&self) -> usize { self.spill_tx.len() } +} + +fn spill_worker_loop(rx: Receiver, metrics: Arc) { + while let Ok(job) = rx.recv() { + let op_id = job.store.operator_id; + let epoch = job.epoch; + let start = Instant::now(); + + let result = catch_unwind(AssertUnwindSafe(|| { + job.store.execute_spill_sync(job.epoch, job.data, job.tombstone_snapshot, &metrics) + })); + + let duration_ms = start.elapsed().as_millis(); + metrics.record_spill_duration(op_id, duration_ms); + + match result { + Ok(Ok(())) => tracing::debug!(op_id, epoch, duration_ms, "Spill success"), + Ok(Err(e)) => tracing::error!(op_id, epoch, duration_ms, %e, "Spill I/O Error"), + Err(_) => tracing::error!(op_id, epoch, "CRITICAL: Spill thread PANICKED! 
Recovered."), + } + } +} + +fn compact_worker_loop(rx: Receiver, metrics: Arc) { + while let Ok(job) = rx.recv() { + let (store, is_major) = match job { + CompactJob::Minor { store } => (store, false), + CompactJob::Major { store } => (store, true), + }; + + let op_id = store.operator_id; + let start = Instant::now(); + + let result = catch_unwind(AssertUnwindSafe(|| { + store.execute_compact_sync(is_major, &metrics) + })); + + let duration_ms = start.elapsed().as_millis(); + metrics.record_compaction_duration(op_id, is_major, duration_ms); + + match result { + Ok(Ok(())) => tracing::info!(op_id, is_major, duration_ms, "Compaction success"), + Ok(Err(e)) => tracing::error!(op_id, is_major, duration_ms, %e, "Compaction I/O Error"), + Err(_) => tracing::error!(op_id, is_major, "CRITICAL: Compact thread PANICKED!"), + } + } +} diff --git a/src/runtime/streaming/state/metrics.rs b/src/runtime/streaming/state/metrics.rs new file mode 100644 index 00000000..c6d5ae4e --- /dev/null +++ b/src/runtime/streaming/state/metrics.rs @@ -0,0 +1,18 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. + +pub trait StateMetricsCollector: Send + Sync + 'static { + fn record_memory_usage(&self, operator_id: u32, bytes: usize); + fn record_spill_duration(&self, operator_id: u32, duration_ms: u128); + fn record_compaction_duration(&self, operator_id: u32, is_major: bool, duration_ms: u128); + fn inc_io_errors(&self, operator_id: u32); +} + +/// Default no-op implementation. 
+pub struct NoopMetricsCollector; +impl StateMetricsCollector for NoopMetricsCollector { + fn record_memory_usage(&self, _: u32, _: usize) {} + fn record_spill_duration(&self, _: u32, _: u128) {} + fn record_compaction_duration(&self, _: u32, _: bool, _: u128) {} + fn inc_io_errors(&self, _: u32) {} +} diff --git a/src/runtime/streaming/state/mod.rs b/src/runtime/streaming/state/mod.rs new file mode 100644 index 00000000..07e57527 --- /dev/null +++ b/src/runtime/streaming/state/mod.rs @@ -0,0 +1,25 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod error; +pub mod metrics; +mod io_manager; +mod operator_state; + +#[allow(unused_imports)] +pub use error::{StateEngineError, Result}; +#[allow(unused_imports)] +pub use metrics::{StateMetricsCollector, NoopMetricsCollector}; +#[allow(unused_imports)] +pub use io_manager::{CompactJob, IoManager, IoPool, SpillJob}; +#[allow(unused_imports)] +pub use operator_state::{MemoryController, OperatorStateStore}; diff --git a/src/runtime/streaming/state/operator_state.rs b/src/runtime/streaming/state/operator_state.rs new file mode 100644 index 00000000..f420d2d6 --- /dev/null +++ b/src/runtime/streaming/state/operator_state.rs @@ -0,0 +1,586 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+ +use super::error::{Result, StateEngineError}; +use super::io_manager::{CompactJob, IoManager, SpillJob}; +use super::metrics::StateMetricsCollector; +use arrow_array::builder::{BinaryBuilder, BooleanBuilder, UInt64Builder}; +use arrow_array::{Array, BinaryArray, RecordBatch, UInt64Array}; +use arrow_schema::{DataType, Field, Schema}; +use crossbeam_channel::TrySendError; +use parking_lot::{Mutex, RwLock}; +use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use parquet::arrow::{ArrowWriter, ProjectionMask}; +use parquet::file::properties::WriterProperties; +use std::collections::{HashMap, HashSet, VecDeque}; +use std::fs::{self, File}; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::Notify; +use uuid::Uuid; + +pub(crate) const PARTITION_KEY_COL: &str = "__fs_partition_key"; + +pub type PartitionKey = Vec; +pub type MemTable = HashMap>; +pub type TombstoneMap = HashMap; + +const TOMBSTONE_ENTRY_OVERHEAD: usize = 8 + 16; + +#[derive(Debug)] +pub struct MemoryController { + current_usage: AtomicUsize, + hard_limit: usize, + soft_limit: usize, +} + +impl MemoryController { + pub fn new(soft_limit: usize, hard_limit: usize) -> Arc { + Arc::new(Self { current_usage: AtomicUsize::new(0), hard_limit, soft_limit }) + } + pub fn exceeds_hard_limit(&self, incoming: usize) -> bool { + self.current_usage.load(Ordering::Relaxed) + incoming > self.hard_limit + } + pub fn should_spill(&self) -> bool { + self.current_usage.load(Ordering::Relaxed) > self.soft_limit + } + pub fn record_inc(&self, bytes: usize) { + self.current_usage.fetch_add(bytes, Ordering::Relaxed); + } + pub fn record_dec(&self, bytes: usize) { + self.current_usage.fetch_sub(bytes, Ordering::Relaxed); + } + pub fn usage_bytes(&self) -> usize { + self.current_usage.load(Ordering::Relaxed) + } +} + +pub struct OperatorStateStore { + pub operator_id: u32, + current_epoch: AtomicU64, + + active_table: 
RwLock, + immutable_tables: Mutex>, + + data_files: RwLock>, + tombstone_files: RwLock>, + tombstones: RwLock, + + mem_ctrl: Arc, + io_manager: IoManager, + + data_dir: PathBuf, + tombstone_dir: PathBuf, + + spill_notify: Arc, + is_spilling: AtomicBool, + is_compacting: AtomicBool, +} + +impl OperatorStateStore { + pub fn new( + operator_id: u32, + base_dir: impl AsRef, + mem_ctrl: Arc, + io_manager: IoManager, + ) -> Result> { + let op_dir = base_dir.as_ref().join(format!("op_{operator_id}")); + let data_dir = op_dir.join("data"); + let tombstone_dir = op_dir.join("tombstones"); + + fs::create_dir_all(&data_dir).map_err(StateEngineError::IoError)?; + fs::create_dir_all(&tombstone_dir).map_err(StateEngineError::IoError)?; + + Ok(Arc::new(Self { + operator_id, + current_epoch: AtomicU64::new(1), + active_table: RwLock::new(HashMap::new()), + immutable_tables: Mutex::new(VecDeque::new()), + data_files: RwLock::new(Vec::new()), + tombstone_files: RwLock::new(Vec::new()), + tombstones: RwLock::new(HashMap::new()), + mem_ctrl, + io_manager, + data_dir, + tombstone_dir, + spill_notify: Arc::new(Notify::new()), + is_spilling: AtomicBool::new(false), + is_compacting: AtomicBool::new(false), + })) + } + + pub async fn put(self: &Arc, key: PartitionKey, batch: RecordBatch) -> Result<()> { + let size = batch.get_array_memory_size(); + while self.mem_ctrl.exceeds_hard_limit(size) { + self.trigger_spill(); + self.spill_notify.notified().await; + } + + self.mem_ctrl.record_inc(size); + self.active_table.write().entry(key).or_default().push(batch); + + if self.mem_ctrl.should_spill() { + self.downgrade_active_table(self.current_epoch.load(Ordering::Acquire)); + self.trigger_spill(); + } + Ok(()) + } + + pub fn remove_batches(&self, key: PartitionKey) -> Result<()> { + let current_ep = self.current_epoch.load(Ordering::Acquire); + let tombstone_mem_size = key.len() + TOMBSTONE_ENTRY_OVERHEAD; + + { + let mut tb_guard = self.tombstones.write(); + if tb_guard.insert(key.clone(), 
current_ep).is_none() { + self.mem_ctrl.record_inc(tombstone_mem_size); + } + } + + if let Some(batches) = self.active_table.write().remove(&key) { + let released: usize = batches.iter().map(|b| b.get_array_memory_size()).sum(); + self.mem_ctrl.record_dec(released); + } + + let mut imm = self.immutable_tables.lock(); + for (_, table) in imm.iter_mut() { + if let Some(batches) = table.remove(&key) { + let released: usize = batches.iter().map(|b| b.get_array_memory_size()).sum(); + self.mem_ctrl.record_dec(released); + } + } + + Ok(()) + } + + pub fn snapshot_epoch(self: &Arc, epoch: u64) -> Result<()> { + self.downgrade_active_table(epoch); + self.trigger_spill(); + self.current_epoch.store(epoch.saturating_add(1), Ordering::Release); + Ok(()) + } + + fn downgrade_active_table(&self, epoch: u64) { + let mut active_guard = self.active_table.write(); + if active_guard.is_empty() { return; } + let old_active = std::mem::take(&mut *active_guard); + self.immutable_tables.lock().push_back((epoch, old_active)); + } + + pub async fn get_batches(&self, key: &[u8]) -> Result> { + let deleted_epoch = self.tombstones.read().get(key).copied(); + let mut out = Vec::new(); + + if let Some(batches) = self.active_table.read().get(key) { + out.extend(batches.clone()); + } + + for (table_epoch, table) in self.immutable_tables.lock().iter().rev() { + if let Some(del_ep) = deleted_epoch { + if *table_epoch <= del_ep { continue; } + } + if let Some(batches) = table.get(key) { + out.extend(batches.clone()); + } + } + + let paths: Vec = self.data_files.read().clone(); + if paths.is_empty() { return Ok(out); } + + let pk = key.to_vec(); + let merged = tokio::task::spawn_blocking(move || -> Result> { + let mut acc = Vec::new(); + for path in paths { + let file_epoch = extract_epoch(&path); + if let Some(del_ep) = deleted_epoch { + if file_epoch <= del_ep { continue; } + } + + // Native Bloom Filter intercepts empty reads here + let file = 
File::open(&path).map_err(StateEngineError::IoError)?; + let mut reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; + for maybe in reader.by_ref() { + if let Some(filtered) = filter_and_strip_partition_key(&maybe?, &pk)? { + acc.push(filtered); + } + } + } + Ok(acc) + }).await.map_err(|_| StateEngineError::Corruption("Tokio task panicked".into()))??; + + out.extend(merged); + Ok(out) + } + + fn trigger_spill(self: &Arc) { + if !self.is_spilling.swap(true, Ordering::SeqCst) { + let target = self.immutable_tables.lock().pop_front(); + let Some((epoch, data)) = target else { + self.is_spilling.store(false, Ordering::SeqCst); + self.spill_notify.notify_waiters(); + return; + }; + + let tombstone_snapshot = self.tombstones.read().clone(); + let job = SpillJob { store: self.clone(), epoch, data, tombstone_snapshot }; + + match self.io_manager.try_send_spill(job) { + Ok(()) => {} + Err(TrySendError::Full(j)) | Err(TrySendError::Disconnected(j)) => { + self.immutable_tables.lock().push_front((j.epoch, j.data)); + self.is_spilling.store(false, Ordering::SeqCst); + self.spill_notify.notify_waiters(); + } + } + } + } + + pub fn trigger_minor_compaction(self: &Arc) { + if !self.is_compacting.swap(true, Ordering::SeqCst) { + let _ = self.io_manager.try_send_compact(CompactJob::Minor { store: self.clone() }); + } + } + + pub fn trigger_major_compaction(self: &Arc) { + if !self.is_compacting.swap(true, Ordering::SeqCst) { + let _ = self.io_manager.try_send_compact(CompactJob::Major { store: self.clone() }); + } + } + + pub(crate) fn execute_spill_sync( + self: &Arc, + epoch: u64, + data: MemTable, + tombstones: TombstoneMap, + metrics: &Arc, + ) -> Result<()> { + let mut batches_to_write = Vec::new(); + let mut size_to_release: usize = 0; + let distinct_keys_count = data.len() as u64; + + for (key, batches) in data { + for batch in batches { + size_to_release += batch.get_array_memory_size(); + batches_to_write.push(inject_partition_key(&batch, &key)?); + } + } + 
+ if !batches_to_write.is_empty() { + let path = self.data_dir.join(Self::generate_data_file_name(epoch)); + if let Err(e) = write_parquet_with_bloom_atomic(&path, &batches_to_write, distinct_keys_count) { + metrics.inc_io_errors(self.operator_id); + let restored = restore_memtable_from_injected_batches(batches_to_write)?; + self.immutable_tables.lock().push_front((epoch, restored)); + self.is_spilling.store(false, Ordering::SeqCst); + self.spill_notify.notify_waiters(); + return Err(e); + } + self.data_files.write().push(path); + } + + if !tombstones.is_empty() { + let mut key_builder = BinaryBuilder::new(); + let mut epoch_builder = UInt64Builder::new(); + let tomb_ndv = tombstones.len() as u64; + + for (key, del_epoch) in tombstones.iter() { + key_builder.append_value(key); + epoch_builder.append_value(*del_epoch); + } + + let schema = Arc::new(Schema::new(vec![ + Field::new("deleted_key", DataType::Binary, false), + Field::new("deleted_epoch", DataType::UInt64, false), + ])); + let batch = RecordBatch::try_new( + schema, vec![Arc::new(key_builder.finish()), Arc::new(epoch_builder.finish())] + )?; + + let path = self.tombstone_dir.join(Self::generate_tombstone_file_name(epoch)); + if let Err(e) = write_parquet_with_bloom_atomic(&path, &[batch], tomb_ndv) { + metrics.inc_io_errors(self.operator_id); + return Err(e); + } + self.tombstone_files.write().push(path); + } + + self.mem_ctrl.record_dec(size_to_release); + metrics.record_memory_usage(self.operator_id, self.mem_ctrl.usage_bytes()); + + self.is_spilling.store(false, Ordering::SeqCst); + self.spill_notify.notify_waiters(); + + if !self.immutable_tables.lock().is_empty() { + self.trigger_spill(); + } + Ok(()) + } + + pub(crate) fn execute_compact_sync( + self: &Arc, + is_major: bool, + metrics: &Arc + ) -> Result<()> { + let result = (|| -> Result<()> { + let files_to_merge = { + let df = self.data_files.read(); + if df.len() < 2 { return Ok(()); } + if is_major { df.clone() } else { 
df.iter().take(2).cloned().collect() } + }; + + let tombstone_snapshot = self.tombstones.read().clone(); + let compacted_watermark_epoch = files_to_merge.iter().map(|p| extract_epoch(p)).max().unwrap_or(0); + let new_path = self.data_dir.join(Self::generate_data_file_name(compacted_watermark_epoch)); + + let mut all_batches = Vec::new(); + let mut estimated_rows = 0; + + for path in &files_to_merge { + let file_epoch = extract_epoch(path); + let file = File::open(path).map_err(StateEngineError::IoError)?; + let mut reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; + while let Some(batch) = reader.next() { + let b = batch?; + if let Some(filtered) = filter_tombstones_from_batch(&b, &tombstone_snapshot, file_epoch)? { + estimated_rows += filtered.num_rows() as u64; + all_batches.push(filtered); + } + } + } + + if !all_batches.is_empty() { + if let Err(e) = write_parquet_with_bloom_atomic(&new_path, &all_batches, estimated_rows.max(100)) { + metrics.inc_io_errors(self.operator_id); + return Err(e); + } + let mut df = self.data_files.write(); + df.retain(|p| !files_to_merge.contains(p)); + df.push(new_path); + } else { + let mut df = self.data_files.write(); + df.retain(|p| !files_to_merge.contains(p)); + } + + for path in &files_to_merge { let _ = fs::remove_file(path); } + + // Watermark GC + { + let mut tg = self.tombstones.write(); + let mut memory_freed = 0; + + tg.retain(|key, deleted_epoch| { + if *deleted_epoch <= compacted_watermark_epoch { + memory_freed += key.len() + TOMBSTONE_ENTRY_OVERHEAD; + false + } else { true } + }); + + if memory_freed > 0 { + self.mem_ctrl.record_dec(memory_freed); + metrics.record_memory_usage(self.operator_id, self.mem_ctrl.usage_bytes()); + } + } + + { + let mut tf_guard = self.tombstone_files.write(); + tf_guard.retain(|p| { + if extract_epoch(p) <= compacted_watermark_epoch { + let _ = fs::remove_file(p); + return false; + } + true + }); + } + + Ok(()) + })(); + + self.is_compacting.store(false, 
Ordering::SeqCst); + result + } + + pub async fn restore_metadata(&self, safe_epoch: u64) -> Result> { + self.active_table.write().clear(); + self.immutable_tables.lock().retain(|(e, _)| *e <= safe_epoch); + + let cleanup_future = |files: &mut Vec| { + files.retain(|path| { + if extract_epoch(path) > safe_epoch { + let _ = fs::remove_file(path); + false + } else { true } + }); + }; + cleanup_future(&mut self.data_files.write()); + cleanup_future(&mut self.tombstone_files.write()); + + let tomb_paths = self.tombstone_files.read().clone(); + let loaded_tombstones = tokio::task::spawn_blocking(move || -> Result { + let mut map = HashMap::new(); + for path in tomb_paths { + let file = File::open(&path).map_err(StateEngineError::IoError)?; + let mut reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; + while let Some(batch) = reader.next() { + let batch = batch?; + let key_col = batch.column(0).as_any().downcast_ref::().unwrap(); + let ep_col = batch.column(1).as_any().downcast_ref::().unwrap(); + + for i in 0..key_col.len() { + let k = key_col.value(i).to_vec(); + let e = ep_col.value(i); + let current_max = map.get(&k).copied().unwrap_or(0); + if e > current_max { map.insert(k, e); } + } + } + } + Ok(map) + }).await.map_err(|_| StateEngineError::Corruption("Task Panicked".into()))??; + + let mut total_tombstone_mem = 0; + for key in loaded_tombstones.keys() { + total_tombstone_mem += key.len() + TOMBSTONE_ENTRY_OVERHEAD; + } + self.mem_ctrl.record_inc(total_tombstone_mem); + *self.tombstones.write() = loaded_tombstones.clone(); + + let data_paths = self.data_files.read().clone(); + let active_keys = tokio::task::spawn_blocking(move || -> Result> { + let mut keys = HashSet::new(); + for path in data_paths { + let file_epoch = extract_epoch(&path); + let file = File::open(&path).map_err(StateEngineError::IoError)?; + let builder = ParquetRecordBatchReaderBuilder::try_new(file)?; + let schema = builder.parquet_schema(); + let mask = 
ProjectionMask::leaves(schema, vec![schema.columns().len() - 1]); + let mut reader = builder.with_projection(mask).build()?; + + while let Some(batch) = reader.next() { + let batch = batch?; + let key_col = batch.column(0).as_any().downcast_ref::().unwrap(); + for i in 0..key_col.len() { + let k = key_col.value(i).to_vec(); + let is_active = match loaded_tombstones.get(&k) { + Some(del_ep) => *del_ep < file_epoch, + None => true, + }; + if is_active { keys.insert(k); } + } + } + } + Ok(keys) + }).await.map_err(|_| StateEngineError::Corruption("Task Panicked".into()))??; + + self.current_epoch.store(safe_epoch + 1, Ordering::Release); + Ok(active_keys) + } + + // ======================================================================== + // UUID-based file name generators + // ======================================================================== + + fn generate_data_file_name(epoch: u64) -> String { + format!("data-epoch-{}_uuid-{}.parquet", epoch, Uuid::now_v7()) + } + + fn generate_tombstone_file_name(epoch: u64) -> String { + format!("tombstone-epoch-{}_uuid-{}.parquet", epoch, Uuid::now_v7()) + } +} + +// ============================================================================ +// Internal helper functions +// ============================================================================ + +fn write_parquet_with_bloom_atomic(path: &Path, batches: &[RecordBatch], ndv: u64) -> Result<()> { + if batches.is_empty() { return Ok(()); } + let tmp = path.with_extension("tmp"); + { + let file = File::create(&tmp).map_err(StateEngineError::IoError)?; + let props = WriterProperties::builder() + .set_bloom_filter_enabled(true) + .set_bloom_filter_ndv(ndv) + .build(); + + let mut writer = ArrowWriter::try_new(&file, batches[0].schema(), Some(props))?; + for b in batches { writer.write(b)?; } + writer.close()?; + file.sync_all().map_err(StateEngineError::IoError)?; + } + fs::rename(&tmp, path).map_err(StateEngineError::IoError)?; + Ok(()) +} + +fn extract_epoch(path: 
&Path) -> u64 { + let name = path.file_name().unwrap_or_default().to_str().unwrap_or_default(); + if let Some(start) = name.find("-epoch-") { + let after = &name[start + 7..]; + if let Some(end) = after.find("_uuid-") { + return after[..end].parse().unwrap_or(0); + } + } + 0 +} + +fn inject_partition_key(batch: &RecordBatch, key: &[u8]) -> Result { + let mut fields = batch.schema().fields().to_vec(); + fields.push(Arc::new(Field::new(PARTITION_KEY_COL, DataType::Binary, false))); + let schema = Arc::new(Schema::new(fields)); + let key_array = Arc::new(BinaryArray::from_iter_values(std::iter::repeat_n(key, batch.num_rows()))); + let mut cols = batch.columns().to_vec(); + cols.push(key_array as Arc); + Ok(RecordBatch::try_new(schema, cols)?) +} + +fn filter_tombstones_from_batch( + batch: &RecordBatch, + tombstones: &TombstoneMap, + file_epoch: u64, +) -> Result> { + if tombstones.is_empty() { return Ok(Some(batch.clone())); } + let Ok(idx) = batch.schema().index_of(PARTITION_KEY_COL) else { return Ok(Some(batch.clone())); }; + + let key_col = batch.column(idx).as_any().downcast_ref::().unwrap(); + let mut mask_builder = BooleanBuilder::with_capacity(batch.num_rows()); + let mut has_valid = false; + + for i in 0..batch.num_rows() { + let key = key_col.value(i).to_vec(); + let keep = match tombstones.get(&key) { + Some(deleted_epoch) => *deleted_epoch < file_epoch, + None => true, + }; + mask_builder.append_value(keep); + if keep { has_valid = true; } + } + + if !has_valid { return Ok(None); } + let mask = mask_builder.finish(); + Ok(Some(arrow::compute::filter_record_batch(batch, &mask)?)) +} + +fn filter_and_strip_partition_key(batch: &RecordBatch, target_key: &[u8]) -> Result> { + let Ok(idx) = batch.schema().index_of(PARTITION_KEY_COL) else { return Ok(Some(batch.clone())); }; + let key_col = batch.column(idx).as_any().downcast_ref::().unwrap(); + let mut mask_builder = BooleanBuilder::with_capacity(batch.num_rows()); + for i in 0..batch.num_rows() { 
mask_builder.append_value(key_col.value(i) == target_key); } + let mask = mask_builder.finish(); + let filtered = arrow::compute::filter_record_batch(batch, &mask)?; + if filtered.num_rows() == 0 { return Ok(None); } + let mut proj: Vec = (0..filtered.num_columns()).collect(); + proj.retain(|&i| i != idx); + Ok(Some(filtered.project(&proj)?)) +} + +fn restore_memtable_from_injected_batches(batches: Vec) -> Result { + let mut m = MemTable::new(); + for batch in batches { + let idx = batch.schema().index_of(PARTITION_KEY_COL).unwrap(); + let key_col = batch.column(idx).as_any().downcast_ref::().unwrap(); + let pk = key_col.value(0).to_vec(); + let mut proj: Vec = (0..batch.num_columns()).collect(); + proj.retain(|&i| i != idx); + m.entry(pk).or_default().push(batch.project(&proj)?); + } + Ok(m) +} diff --git a/src/server/initializer.rs b/src/server/initializer.rs index 785321b8..c1e11569 100644 --- a/src/server/initializer.rs +++ b/src/server/initializer.rs @@ -158,7 +158,7 @@ fn initialize_python_service(config: &GlobalConfig) -> Result<()> { fn initialize_job_manager(config: &GlobalConfig) -> Result<()> { use crate::runtime::streaming::factory::OperatorFactory; use crate::runtime::streaming::factory::Registry; - use crate::runtime::streaming::job::JobManager; + use crate::runtime::streaming::job::{JobManager, StateConfig}; use std::sync::Arc; let registry = Arc::new(Registry::new()); @@ -168,7 +168,11 @@ fn initialize_job_manager(config: &GlobalConfig) -> Result<()> { .max_memory_bytes .unwrap_or(256 * 1024 * 1024); - JobManager::init(factory, max_memory_bytes).context("JobManager service failed to start")?; + let state_base_dir = std::env::temp_dir().join("function-stream").join("state"); + let state_config = StateConfig::default(); + + JobManager::init(factory, max_memory_bytes, state_base_dir, state_config) + .context("JobManager service failed to start")?; Ok(()) } diff --git a/src/storage/stream_catalog/manager.rs b/src/storage/stream_catalog/manager.rs index 
3804a95a..9691e991 100644 --- a/src/storage/stream_catalog/manager.rs +++ b/src/storage/stream_catalog/manager.rs @@ -17,7 +17,7 @@ use datafusion::common::{Result as DFResult, internal_err, plan_err}; use prost::Message; use protocol::function_stream_graph::FsProgram; use protocol::storage::{self as pb, table_definition}; -use tracing::{info, warn}; +use tracing::{info, warn, debug}; use unicase::UniCase; use crate::sql::common::constants::sql_field; @@ -88,6 +88,7 @@ impl CatalogManager { table_name: &str, fs_program: &FsProgram, comment: &str, + checkpoint_interval_ms: u64, ) -> DFResult<()> { let program_bytes = fs_program.encode_to_vec(); let def = pb::StreamingTableDefinition { @@ -95,11 +96,13 @@ impl CatalogManager { created_at_millis: chrono::Utc::now().timestamp_millis(), fs_program_bytes: program_bytes, comment: comment.to_string(), + checkpoint_interval_ms, + latest_checkpoint_epoch: 0, }; let payload = def.encode_to_vec(); let key = Self::build_streaming_job_key(table_name); self.store.put(&key, payload)?; - info!(table = %table_name, "Streaming job definition persisted"); + info!(table = %table_name, interval_ms = checkpoint_interval_ms, "Streaming job definition persisted"); Ok(()) } @@ -110,7 +113,40 @@ impl CatalogManager { Ok(()) } - pub fn load_streaming_job_definitions(&self) -> DFResult> { + /// Persist the globally-completed checkpoint epoch after all operators ACK. + /// Only advances forward; stale epochs are silently ignored. 
+ pub fn commit_job_checkpoint(&self, table_name: &str, epoch: u64) -> DFResult<()> { + let key = Self::build_streaming_job_key(table_name); + + let current_payload = self.store.get(&key)?.ok_or_else(|| { + datafusion::common::DataFusionError::Plan(format!( + "Cannot commit checkpoint: Streaming job '{}' not found in catalog", + table_name + )) + })?; + + let mut def = + pb::StreamingTableDefinition::decode(current_payload.as_slice()).map_err(|e| { + datafusion::common::DataFusionError::Execution(format!( + "Protobuf decode error: {}", + e + )) + })?; + + if epoch > def.latest_checkpoint_epoch { + def.latest_checkpoint_epoch = epoch; + let new_payload = def.encode_to_vec(); + self.store.put(&key, new_payload)?; + debug!(table = %table_name, epoch = epoch, "Checkpoint metadata committed to Catalog"); + } + + Ok(()) + } + + /// Returns (table_name, program, checkpoint_interval_ms, latest_checkpoint_epoch). + pub fn load_streaming_job_definitions( + &self, + ) -> DFResult> { let records = self.store.scan_prefix(STREAMING_JOB_KEY_PREFIX)?; let mut out = Vec::with_capacity(records.len()); for (key, payload) in records { @@ -136,7 +172,12 @@ impl CatalogManager { continue; } }; - out.push((def.table_name, program)); + out.push(( + def.table_name, + program, + def.checkpoint_interval_ms, + def.latest_checkpoint_epoch, + )); } Ok(out) } @@ -522,12 +563,28 @@ pub fn restore_streaming_jobs_from_store() { let mut restored = 0usize; let mut failed = 0usize; - for (table_name, fs_program) in definitions { + for (table_name, fs_program, interval_ms, latest_epoch) in definitions { let jm = job_manager.clone(); let name = table_name.clone(); - match rt.block_on(jm.submit_job(name.clone(), fs_program)) { + + let custom_interval = if interval_ms > 0 { + Some(interval_ms) + } else { + None + }; + let recovery_epoch = if latest_epoch > 0 { + Some(latest_epoch) + } else { + None + }; + + match rt.block_on(jm.submit_job(name.clone(), fs_program, custom_interval, recovery_epoch)) + { 
Ok(job_id) => { - info!(table = %table_name, job_id = %job_id, "Streaming job restored"); + info!( + table = %table_name, job_id = %job_id, + epoch = latest_epoch, "Streaming job restored" + ); restored += 1; } Err(e) => { From 7547a3fbb633e2f910995f73ea6b8751aca5f41c Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 12 Apr 2026 21:57:19 +0800 Subject: [PATCH 02/26] feat(streaming): persist Join operator state with LSM-Tree and 3-phase watermark harvesting - Extend TaskContext with state_dir, memory_controller, io_manager, and safe_epoch to bridge operators with the state engine - Refactor JoinWithExpirationOperator: replace in-memory VecDeque with PersistentStateBuffer backed by OperatorStateStore, using composite keys [Side(1B) + Timestamp(8B BE)] and BTreeSet timeline index - Refactor InstantJoinOperator: replace in-memory BTreeMap with LSM-Tree persistence, split process_watermark into 3-phase pipeline (harvest -> compute -> cleanup) to eliminate interleaved mutable/immutable borrow conflicts - Both operators now support on_start recovery via restore_metadata and snapshot_state via snapshot_epoch for exactly-once semantics Made-with: Cursor --- src/coordinator/execution/executor.rs | 13 +- src/runtime/streaming/api/context.rs | 27 ++ src/runtime/streaming/job/job_manager.rs | 26 +- .../operators/joins/join_instance.rs | 321 ++++++++++++------ .../operators/joins/join_with_expiration.rs | 172 ++++++++-- src/runtime/streaming/state/error.rs | 2 +- src/runtime/streaming/state/io_manager.rs | 17 +- src/runtime/streaming/state/mod.rs | 8 +- src/runtime/streaming/state/operator_state.rs | 240 +++++++++---- src/storage/stream_catalog/manager.rs | 6 +- 10 files changed, 605 insertions(+), 227 deletions(-) diff --git a/src/coordinator/execution/executor.rs b/src/coordinator/execution/executor.rs index 399ab775..7dc3c0ff 100644 --- a/src/coordinator/execution/executor.rs +++ b/src/coordinator/execution/executor.rs @@ -337,15 +337,16 @@ impl PlanVisitor for Executor { 
custom_interval.unwrap_or(0), ) .map_err(|e| { - ExecuteError::Internal(format!( - "Streaming job persistence failed: {e}", - )) + ExecuteError::Internal(format!("Streaming job persistence failed: {e}",)) })?; let job_id = tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on( - job_manager.submit_job(job_id, fs_program, custom_interval, None), - ) + tokio::runtime::Handle::current().block_on(job_manager.submit_job( + job_id, + fs_program, + custom_interval, + None, + )) }) .map_err(|e| ExecuteError::Internal(format!("Failed to submit streaming job: {e}")))?; diff --git a/src/runtime/streaming/api/context.rs b/src/runtime/streaming/api/context.rs index f9dc805e..f7c47a8f 100644 --- a/src/runtime/streaming/api/context.rs +++ b/src/runtime/streaming/api/context.rs @@ -10,6 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, SystemTime}; @@ -19,6 +20,7 @@ use arrow_array::RecordBatch; use crate::runtime::streaming::memory::MemoryPool; use crate::runtime::streaming::network::endpoint::PhysicalSender; use crate::runtime::streaming::protocol::event::{StreamEvent, TrackedEvent}; +use crate::runtime::streaming::state::{IoManager, MemoryController}; #[derive(Debug, Clone)] pub struct TaskContextConfig { @@ -61,6 +63,18 @@ pub struct TaskContext { /// Subtask-level tunables. config: TaskContextConfig, + + /// Root directory for operator state persistence (LSM-Tree data/tombstone files). + pub state_dir: PathBuf, + + /// Shared memory controller for state engine back-pressure. + pub memory_controller: Arc, + + /// I/O thread pool handle for background spill/compaction. + pub io_manager: IoManager, + + /// Last globally-committed safe epoch for crash recovery. 
+ safe_epoch: u64, } impl TaskContext { @@ -71,6 +85,10 @@ impl TaskContext { parallelism: u32, downstream_senders: Vec, memory_pool: Arc, + memory_controller: Arc, + io_manager: IoManager, + state_dir: PathBuf, + safe_epoch: u64, ) -> Self { let task_name = format!( "Task-[{}]-Pipe[{}]-Sub[{}/{}]", @@ -87,9 +105,18 @@ impl TaskContext { memory_pool, current_watermark: None, config: TaskContextConfig::default(), + state_dir, + memory_controller, + io_manager, + safe_epoch, } } + #[inline] + pub fn latest_safe_epoch(&self) -> u64 { + self.safe_epoch + } + #[inline] pub fn config(&self) -> &TaskContextConfig { &self.config diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index 011a912e..a7c982c4 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -19,7 +19,7 @@ use anyhow::{Context, Result, anyhow, bail, ensure}; use tokio::sync::mpsc; use tokio::task::JoinHandle as TokioJoinHandle; use tokio_stream::wrappers::ReceiverStream; -use tracing::{error, info, warn, debug}; +use tracing::{debug, error, info, warn}; use protocol::function_stream_graph::{ChainedOperator, FsProgram}; @@ -34,7 +34,7 @@ use crate::runtime::streaming::job::models::{ }; use crate::runtime::streaming::memory::MemoryPool; use crate::runtime::streaming::network::endpoint::{BoxedEventStream, PhysicalSender}; -use crate::runtime::streaming::protocol::control::{ControlCommand, StopMode, JobMasterEvent}; +use crate::runtime::streaming::protocol::control::{ControlCommand, JobMasterEvent, StopMode}; use crate::runtime::streaming::protocol::event::CheckpointBarrier; use crate::runtime::streaming::state::{IoManager, IoPool, MemoryController, NoopMetricsCollector}; use crate::storage::stream_catalog::CatalogManager; @@ -132,7 +132,8 @@ impl JobManager { state_config.max_background_spills, state_config.max_background_compactions, metrics, - ).context("Failed to initialize state engine I/O pool")?; + ) + 
.context("Failed to initialize state engine I/O pool")?; Ok(Self { active_jobs: Arc::new(RwLock::new(HashMap::new())), @@ -153,7 +154,12 @@ impl JobManager { state_config: StateConfig, ) -> Result<()> { GLOBAL_JOB_MANAGER - .set(Arc::new(Self::new(factory, memory_bytes, state_base_dir, state_config)?)) + .set(Arc::new(Self::new( + factory, + memory_bytes, + state_base_dir, + state_config, + )?)) .map_err(|_| anyhow!("JobManager singleton already initialized")) } @@ -217,8 +223,8 @@ impl JobManager { pipelines.insert(pipeline_id, pipeline); } - let interval_ms = custom_checkpoint_interval_ms - .unwrap_or(self.state_config.checkpoint_interval_ms); + let interval_ms = + custom_checkpoint_interval_ms.unwrap_or(self.state_config.checkpoint_interval_ms); self.spawn_checkpoint_coordinator( job_id.clone(), @@ -425,9 +431,9 @@ impl JobManager { pipeline_id: u32, operators: &[ChainedOperator], edge_manager: &mut EdgeManager, - _job_state_dir: &Path, + job_state_dir: &Path, _job_master_tx: mpsc::Sender, - _recovery_epoch: u64, + recovery_epoch: u64, ) -> Result<(PhysicalPipeline, bool)> { let (raw_inboxes, raw_outboxes) = edge_manager.take_endpoints(pipeline_id).with_context(|| { @@ -479,6 +485,10 @@ impl JobManager { parallelism, physical_outboxes, Arc::clone(&self.memory_pool), + Arc::clone(&self.memory_controller), + self.io_manager_client.clone(), + job_state_dir.to_path_buf(), + recovery_epoch, ); let runner = if let Some(source) = chain.source { diff --git a/src/runtime/streaming/operators/joins/join_instance.rs b/src/runtime/streaming/operators/joins/join_instance.rs index 75513542..bfb6c416 100644 --- a/src/runtime/streaming/operators/joins/join_instance.rs +++ b/src/runtime/streaming/operators/joins/join_instance.rs @@ -11,9 +11,8 @@ // limitations under the License. 
use anyhow::{Result, anyhow}; -use arrow::compute::{max, min, partition, sort_to_indices, take}; +use arrow::compute::{concat_batches, max, min, partition, sort_to_indices, take}; use arrow_array::{RecordBatch, TimestampNanosecondArray}; -use datafusion::execution::SendableRecordBatchStream; use datafusion::execution::context::SessionContext; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::ExecutionPlan; @@ -21,80 +20,79 @@ use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; use futures::StreamExt; use prost::Message; -use std::collections::BTreeMap; +use std::collections::BTreeSet; use std::sync::{Arc, RwLock}; -use std::time::SystemTime; -use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tracing::warn; +use std::time::UNIX_EPOCH; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; -use crate::sql::common::constants::mem_exec_join_side; -use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos}; +use crate::runtime::streaming::state::OperatorStateStore; +use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; use protocol::function_stream_graph::JoinOperator; #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum JoinSide { - Left, - Right, + Left = 0, + Right = 1, } -impl JoinSide { - #[allow(dead_code)] - fn name(&self) -> &'static str { - match self { - JoinSide::Left => mem_exec_join_side::LEFT, - JoinSide::Right => mem_exec_join_side::RIGHT, - } - } -} +// ============================================================================ +// Lightweight state index: composite key [Side(1B)] + 
[Timestamp(8B BE)] +// ============================================================================ -struct JoinInstance { - left_tx: UnboundedSender, - right_tx: UnboundedSender, - result_stream: SendableRecordBatchStream, +struct InstantStateIndex { + side: JoinSide, + active_timestamps: BTreeSet, } -impl JoinInstance { - fn feed_data(&self, batch: RecordBatch, side: JoinSide) -> Result<()> { - match side { - JoinSide::Left => self - .left_tx - .send(batch) - .map_err(|e| anyhow!("Left send err: {}", e)), - JoinSide::Right => self - .right_tx - .send(batch) - .map_err(|e| anyhow!("Right send err: {}", e)), +impl InstantStateIndex { + fn new(side: JoinSide) -> Self { + Self { + side, + active_timestamps: BTreeSet::new(), } } - async fn close_and_drain(self) -> Result> { - drop(self.left_tx); - drop(self.right_tx); - - let mut outputs = Vec::new(); - let mut stream = self.result_stream; + fn build_key(side: JoinSide, ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(9); + key.push(side as u8); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key + } - while let Some(result_batch) = stream.next().await { - outputs.push(result_batch?); + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 9 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[1..]); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None } - - Ok(outputs) } } +// ============================================================================ +// InstantJoinOperator (persistent state refactor) +// ============================================================================ + pub struct InstantJoinOperator { left_input_schema: FsSchemaRef, right_input_schema: FsSchemaRef, - active_joins: BTreeMap, - left_receiver_hook: Arc>>>, - right_receiver_hook: Arc>>>, + left_schema: FsSchemaRef, + right_schema: FsSchemaRef, + + left_passer: Arc>>, + right_passer: Arc>>, join_exec_plan: Arc, + + left_state: InstantStateIndex, + right_state: InstantStateIndex, + state_store: Option>, 
} impl InstantJoinOperator { @@ -105,32 +103,26 @@ impl InstantJoinOperator { } } - fn get_or_create_join_instance(&mut self, time: SystemTime) -> Result<&mut JoinInstance> { - use std::collections::btree_map::Entry; + async fn compute_pair( + &mut self, + left: RecordBatch, + right: RecordBatch, + ) -> Result> { + self.left_passer.write().unwrap().replace(left); + self.right_passer.write().unwrap().replace(right); - if let Entry::Vacant(e) = self.active_joins.entry(time) { - let (left_tx, left_rx) = unbounded_channel(); - let (right_tx, right_rx) = unbounded_channel(); + self.join_exec_plan.reset().map_err(|e| anyhow!("{e}"))?; - *self.left_receiver_hook.write().unwrap() = Some(left_rx); - *self.right_receiver_hook.write().unwrap() = Some(right_rx); + let mut result_stream = self + .join_exec_plan + .execute(0, SessionContext::new().task_ctx()) + .map_err(|e| anyhow!("{e}"))?; - self.join_exec_plan.reset().map_err(|e| anyhow!("{e}"))?; - let result_stream = self - .join_exec_plan - .execute(0, SessionContext::new().task_ctx()) - .map_err(|e| anyhow!("{e}"))?; - - e.insert(JoinInstance { - left_tx, - right_tx, - result_stream, - }); + let mut outputs = Vec::new(); + while let Some(batch) = result_stream.next().await { + outputs.push(batch.map_err(|e| anyhow!("{e}"))?); } - - self.active_joins - .get_mut(&time) - .ok_or_else(|| anyhow!("join instance missing after insert")) + Ok(outputs) } async fn process_side_internal( @@ -142,6 +134,10 @@ impl InstantJoinOperator { if batch.num_rows() == 0 { return Ok(()); } + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); let time_column = batch .column(self.input_schema(side).timestamp_index) @@ -152,19 +148,28 @@ impl InstantJoinOperator { let min_timestamp = min(time_column).ok_or_else(|| anyhow!("empty timestamp column"))?; let max_timestamp = max(time_column).ok_or_else(|| anyhow!("empty timestamp column"))?; - if let Some(watermark) = ctx.current_watermark() - && watermark > 
from_nanos(min_timestamp as u128) - { - warn!("Dropped late batch from {:?} before watermark", side); - return Ok(()); + if let Some(watermark) = ctx.current_watermark() { + let watermark_nanos = watermark.duration_since(UNIX_EPOCH).unwrap().as_nanos() as i64; + if watermark_nanos > min_timestamp { + warn!("Dropped late batch from {:?} before watermark", side); + return Ok(()); + } } let unkeyed_batch = self.input_schema(side).unkeyed_batch(&batch)?; + let state_index = match side { + JoinSide::Left => &mut self.left_state, + JoinSide::Right => &mut self.right_state, + }; if max_timestamp == min_timestamp { - let time_key = from_nanos(max_timestamp as u128); - let join_instance = self.get_or_create_join_instance(time_key)?; - join_instance.feed_data(unkeyed_batch, side)?; + let ts_nanos = max_timestamp as u64; + let key = InstantStateIndex::build_key(side, ts_nanos); + store + .put(key, unkeyed_batch) + .await + .map_err(|e| anyhow!("{e}"))?; + state_index.active_timestamps.insert(ts_nanos); return Ok(()); } @@ -179,16 +184,21 @@ impl InstantJoinOperator { let typed_timestamps = sorted_timestamps .as_any() .downcast_ref::() - .ok_or_else(|| anyhow!("sorted timestamps downcast failed"))?; + .unwrap(); + let ranges = partition(std::slice::from_ref(&sorted_timestamps)) .unwrap() .ranges(); for range in ranges { let sub_batch = sorted_batch.slice(range.start, range.end - range.start); - let time_key = from_nanos(typed_timestamps.value(range.start) as u128); - let join_instance = self.get_or_create_join_instance(time_key)?; - join_instance.feed_data(sub_batch, side)?; + let ts_nanos = typed_timestamps.value(range.start) as u64; + let key = InstantStateIndex::build_key(side, ts_nanos); + store + .put(key, sub_batch) + .await + .map_err(|e| anyhow!("{e}"))?; + state_index.active_timestamps.insert(ts_nanos); } Ok(()) @@ -201,7 +211,39 @@ impl Operator for InstantJoinOperator { "InstantJoin" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + async fn 
on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + for key in active_keys { + if let Some(ts) = InstantStateIndex::extract_timestamp(&key) { + if key[0] == JoinSide::Left as u8 { + self.left_state.active_timestamps.insert(ts); + } else if key[0] == JoinSide::Right as u8 { + self.right_state.active_timestamps.insert(ts); + } + } + } + + info!( + pipeline_id = ctx.pipeline_id, + restored_left = self.left_state.active_timestamps.len(), + restored_right = self.right_state.active_timestamps.len(), + "Instant Join Operator recovered state." + ); + + self.state_store = Some(store); Ok(()) } @@ -228,24 +270,76 @@ impl Operator for InstantJoinOperator { let Watermark::EventTime(current_time) = watermark else { return Ok(vec![]); }; - let mut emit_outputs = Vec::new(); + let store = self.state_store.clone().unwrap(); + let cutoff_nanos = current_time.duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64; + + let mut all_active_ts = BTreeSet::new(); + all_active_ts.extend(self.left_state.active_timestamps.iter()); + all_active_ts.extend(self.right_state.active_timestamps.iter()); + + let expired_ts: Vec = all_active_ts + .into_iter() + .filter(|&ts| ts < cutoff_nanos) + .collect(); + + if expired_ts.is_empty() { + return Ok(vec![]); + } - let mut expired_times = Vec::new(); - for key in self.active_joins.keys() { - if *key < current_time { - expired_times.push(*key); + // Phase 1: Harvest — extract all expired timestamp data from LSM-Tree + let mut pending_pairs: Vec<(u64, RecordBatch, RecordBatch)> = + Vec::with_capacity(expired_ts.len()); + + for &ts in &expired_ts { + let left_key = 
InstantStateIndex::build_key(JoinSide::Left, ts); + let right_key = InstantStateIndex::build_key(JoinSide::Right, ts); + + let left_batches = store + .get_batches(&left_key) + .await + .map_err(|e| anyhow!("{e}"))?; + let right_batches = store + .get_batches(&right_key) + .await + .map_err(|e| anyhow!("{e}"))?; + + let left_input = if left_batches.is_empty() { + RecordBatch::new_empty(self.left_schema.schema.clone()) } else { - break; - } + concat_batches(&self.left_schema.schema, left_batches.iter())? + }; + let right_input = if right_batches.is_empty() { + RecordBatch::new_empty(self.right_schema.schema.clone()) + } else { + concat_batches(&self.right_schema.schema, right_batches.iter())? + }; + + pending_pairs.push((ts, left_input, right_input)); } - for time_key in expired_times { - if let Some(join_instance) = self.active_joins.remove(&time_key) { - let joined_batches = join_instance.close_and_drain().await?; - for batch in joined_batches { - emit_outputs.push(StreamOutput::Forward(batch)); - } + // Phase 2: Compute — all data extracted, no store reference held + let mut emit_outputs = Vec::new(); + + for (_, left_input, right_input) in pending_pairs { + if left_input.num_rows() == 0 && right_input.num_rows() == 0 { + continue; } + let results = self.compute_pair(left_input, right_input).await?; + for batch in results { + emit_outputs.push(StreamOutput::Forward(batch)); + } + } + + // Phase 3: Cleanup — tombstone LSM-Tree entries and update in-memory index + for ts in expired_ts { + let left_key = InstantStateIndex::build_key(JoinSide::Left, ts); + let right_key = InstantStateIndex::build_key(JoinSide::Right, ts); + store.remove_batches(left_key).map_err(|e| anyhow!("{e}"))?; + store + .remove_batches(right_key) + .map_err(|e| anyhow!("{e}"))?; + self.left_state.active_timestamps.remove(&ts); + self.right_state.active_timestamps.remove(&ts); } Ok(emit_outputs) @@ -253,13 +347,22 @@ impl Operator for InstantJoinOperator { async fn snapshot_state( &mut self, - 
_barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .unwrap() + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } } +// ============================================================================ +// Constructor +// ============================================================================ + pub struct InstantJoinConstructor; impl InstantJoinConstructor { @@ -268,21 +371,23 @@ impl InstantJoinConstructor { config: JoinOperator, registry: Arc, ) -> anyhow::Result { - let join_physical_plan_node = PhysicalPlanNode::decode(&mut config.join_plan.as_slice())?; - let left_input_schema: Arc = Arc::new(config.left_schema.unwrap().try_into()?); let right_input_schema: Arc = Arc::new(config.right_schema.unwrap().try_into()?); - let left_receiver_hook = Arc::new(RwLock::new(None)); - let right_receiver_hook = Arc::new(RwLock::new(None)); + let left_schema = Arc::new(left_input_schema.schema_without_keys()?); + let right_schema = Arc::new(right_input_schema.schema_without_keys()?); + + let left_passer = Arc::new(RwLock::new(None)); + let right_passer = Arc::new(RwLock::new(None)); let codec = StreamingExtensionCodec { - context: StreamingDecodingContext::LockedJoinStream { - left: left_receiver_hook.clone(), - right: right_receiver_hook.clone(), + context: StreamingDecodingContext::LockedJoinPair { + left: left_passer.clone(), + right: right_passer.clone(), }, }; + let join_physical_plan_node = PhysicalPlanNode::decode(&mut config.join_plan.as_slice())?; let join_exec_plan = join_physical_plan_node.try_into_physical_plan( registry.as_ref(), &RuntimeEnvBuilder::new().build()?, @@ -292,10 +397,14 @@ impl InstantJoinConstructor { Ok(InstantJoinOperator { left_input_schema, right_input_schema, - active_joins: BTreeMap::new(), - left_receiver_hook, - right_receiver_hook, + left_schema, + right_schema, + left_passer, + right_passer, join_exec_plan, + 
left_state: InstantStateIndex::new(JoinSide::Left), + right_state: InstantStateIndex::new(JoinSide::Right), + state_store: None, }) } } diff --git a/src/runtime/streaming/operators/joins/join_with_expiration.rs b/src/runtime/streaming/operators/joins/join_with_expiration.rs index 60bbe7e3..4d579715 100644 --- a/src/runtime/streaming/operators/joins/join_with_expiration.rs +++ b/src/runtime/streaming/operators/joins/join_with_expiration.rs @@ -19,15 +19,16 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::{physical_plan::AsExecutionPlan, protobuf::PhysicalPlanNode}; use futures::StreamExt; use prost::Message; -use std::collections::VecDeque; +use std::collections::BTreeSet; use std::sync::{Arc, RwLock}; -use std::time::{Duration, SystemTime}; -use tracing::warn; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; @@ -35,49 +36,91 @@ use protocol::function_stream_graph::JoinOperator; #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum JoinSide { - Left, - Right, + Left = 0, + Right = 1, } // ============================================================================ +// Persistent state buffer: composite key [Side(1B)] + [Timestamp(8B BE)] // ============================================================================ -struct StateBuffer { - batches: VecDeque<(SystemTime, RecordBatch)>, +struct PersistentStateBuffer { + side: JoinSide, ttl: Duration, + active_timestamps: BTreeSet, } -impl StateBuffer { - fn new(ttl: Duration) -> Self { +impl PersistentStateBuffer { + fn new(side: 
JoinSide, ttl: Duration) -> Self { Self { - batches: VecDeque::new(), + side, ttl, + active_timestamps: BTreeSet::new(), } } - fn insert(&mut self, batch: RecordBatch, time: SystemTime) { - self.batches.push_back((time, batch)); + fn build_key(side: JoinSide, ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(9); + key.push(side as u8); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key } - fn expire(&mut self, current_time: SystemTime) { - let cutoff = current_time - .checked_sub(self.ttl) - .unwrap_or(SystemTime::UNIX_EPOCH); - while let Some((time, _)) = self.batches.front() { - if *time < cutoff { - self.batches.pop_front(); - } else { - break; - } + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 9 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[1..]); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None } } - fn get_all_batches(&self) -> Vec { - self.batches.iter().map(|(_, b)| b.clone()).collect() + async fn insert( + &mut self, + batch: RecordBatch, + time: SystemTime, + store: &Arc, + ) -> Result<()> { + let ts_nanos = time.duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64; + self.active_timestamps.insert(ts_nanos); + let key = Self::build_key(self.side, ts_nanos); + store.put(key, batch).await.map_err(|e| anyhow!("{e}")) + } + + fn expire(&mut self, current_time: SystemTime, store: &Arc) -> Result<()> { + let cutoff = current_time.checked_sub(self.ttl).unwrap_or(UNIX_EPOCH); + let cutoff_nanos = cutoff.duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64; + + let expired_ts: Vec = self + .active_timestamps + .iter() + .take_while(|&&ts| ts < cutoff_nanos) + .copied() + .collect(); + + for ts in expired_ts { + let key = Self::build_key(self.side, ts); + store.remove_batches(key).map_err(|e| anyhow!("{e}"))?; + self.active_timestamps.remove(&ts); + } + + Ok(()) + } + + async fn get_all_batches(&self, store: &Arc) -> Result> { + let mut all_batches = Vec::new(); + for &ts in &self.active_timestamps { + let 
key = Self::build_key(self.side, ts); + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; + all_batches.extend(batches); + } + Ok(all_batches) } } // ============================================================================ +// JoinWithExpirationOperator // ============================================================================ pub struct JoinWithExpirationOperator { @@ -90,8 +133,9 @@ pub struct JoinWithExpirationOperator { right_passer: Arc>>, join_exec_plan: Arc, - left_state: StateBuffer, - right_state: StateBuffer, + left_state: PersistentStateBuffer, + right_state: PersistentStateBuffer, + state_store: Option>, } impl JoinWithExpirationOperator { @@ -133,18 +177,30 @@ impl JoinWithExpirationOperator { ctx: &mut TaskContext, ) -> Result> { let current_time = ctx.current_watermark().unwrap_or_else(SystemTime::now); + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); - self.left_state.expire(current_time); - self.right_state.expire(current_time); + self.left_state.expire(current_time, store)?; + self.right_state.expire(current_time, store)?; match side { - JoinSide::Left => self.left_state.insert(batch.clone(), current_time), - JoinSide::Right => self.right_state.insert(batch.clone(), current_time), + JoinSide::Left => { + self.left_state + .insert(batch.clone(), current_time, store) + .await? + } + JoinSide::Right => { + self.right_state + .insert(batch.clone(), current_time, store) + .await? 
+ } } let opposite_batches = match side { - JoinSide::Left => self.right_state.get_all_batches(), - JoinSide::Right => self.left_state.get_all_batches(), + JoinSide::Left => self.right_state.get_all_batches(store).await?, + JoinSide::Right => self.left_state.get_all_batches(store).await?, }; if opposite_batches.is_empty() { @@ -182,7 +238,39 @@ impl Operator for JoinWithExpirationOperator { "JoinWithExpiration" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + for key in active_keys { + if let Some(ts) = PersistentStateBuffer::extract_timestamp(&key) { + if key[0] == JoinSide::Left as u8 { + self.left_state.active_timestamps.insert(ts); + } else if key[0] == JoinSide::Right as u8 { + self.right_state.active_timestamps.insert(ts); + } + } + } + + info!( + pipeline_id = ctx.pipeline_id, + restored_left = self.left_state.active_timestamps.len(), + restored_right = self.right_state.active_timestamps.len(), + "Join Operator restored state from LSM-Tree." 
+ ); + + self.state_store = Some(store); Ok(()) } @@ -210,9 +298,19 @@ impl Operator for JoinWithExpirationOperator { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + + store + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + + info!(epoch = barrier.epoch, "Join Operator snapshotted state."); Ok(()) } @@ -222,6 +320,7 @@ impl Operator for JoinWithExpirationOperator { } // ============================================================================ +// Constructor // ============================================================================ pub struct JoinWithExpirationConstructor; @@ -273,8 +372,9 @@ impl JoinWithExpirationConstructor { left_passer, right_passer, join_exec_plan, - left_state: StateBuffer::new(ttl), - right_state: StateBuffer::new(ttl), + left_state: PersistentStateBuffer::new(JoinSide::Left, ttl), + right_state: PersistentStateBuffer::new(JoinSide::Right, ttl), + state_store: None, }) } } diff --git a/src/runtime/streaming/state/error.rs b/src/runtime/streaming/state/error.rs index e04a022e..10c3c7c5 100644 --- a/src/runtime/streaming/state/error.rs +++ b/src/runtime/streaming/state/error.rs @@ -1,8 +1,8 @@ // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
-use thiserror::Error; use crossbeam_channel::TrySendError; +use thiserror::Error; #[derive(Error, Debug)] pub enum StateEngineError { diff --git a/src/runtime/streaming/state/io_manager.rs b/src/runtime/streaming/state/io_manager.rs index aa85385b..9b37da1d 100644 --- a/src/runtime/streaming/state/io_manager.rs +++ b/src/runtime/streaming/state/io_manager.rs @@ -5,8 +5,8 @@ use super::error::StateEngineError; use super::metrics::StateMetricsCollector; use super::operator_state::{MemTable, OperatorStateStore, TombstoneMap}; -use crossbeam_channel::{bounded, Receiver, Sender, TrySendError}; -use std::panic::{catch_unwind, AssertUnwindSafe}; +use crossbeam_channel::{Receiver, Sender, TrySendError, bounded}; +use std::panic::{AssertUnwindSafe, catch_unwind}; use std::sync::Arc; use std::thread::{self, JoinHandle}; use std::time::Instant; @@ -63,7 +63,11 @@ impl IoPool { }; Ok(( - Self { spill_tx: Some(spill_tx), compact_tx: Some(compact_tx), worker_handles }, + Self { + spill_tx: Some(spill_tx), + compact_tx: Some(compact_tx), + worker_handles, + }, manager, )) } @@ -94,7 +98,9 @@ impl IoManager { pub fn try_send_compact(&self, job: CompactJob) -> Result<(), TrySendError> { self.compact_tx.try_send(job) } - pub fn pending_spills(&self) -> usize { self.spill_tx.len() } + pub fn pending_spills(&self) -> usize { + self.spill_tx.len() + } } fn spill_worker_loop(rx: Receiver, metrics: Arc) { @@ -104,7 +110,8 @@ fn spill_worker_loop(rx: Receiver, metrics: Arc Arc { - Arc::new(Self { current_usage: AtomicUsize::new(0), hard_limit, soft_limit }) + Arc::new(Self { + current_usage: AtomicUsize::new(0), + hard_limit, + soft_limit, + }) } pub fn exceeds_hard_limit(&self, incoming: usize) -> bool { self.current_usage.load(Ordering::Relaxed) + incoming > self.hard_limit @@ -69,7 +73,7 @@ pub struct OperatorStateStore { mem_ctrl: Arc, io_manager: IoManager, - + data_dir: PathBuf, tombstone_dir: PathBuf, @@ -88,7 +92,7 @@ impl OperatorStateStore { let op_dir = 
base_dir.as_ref().join(format!("op_{operator_id}")); let data_dir = op_dir.join("data"); let tombstone_dir = op_dir.join("tombstones"); - + fs::create_dir_all(&data_dir).map_err(StateEngineError::IoError)?; fs::create_dir_all(&tombstone_dir).map_err(StateEngineError::IoError)?; @@ -118,7 +122,11 @@ impl OperatorStateStore { } self.mem_ctrl.record_inc(size); - self.active_table.write().entry(key).or_default().push(batch); + self.active_table + .write() + .entry(key) + .or_default() + .push(batch); if self.mem_ctrl.should_spill() { self.downgrade_active_table(self.current_epoch.load(Ordering::Acquire)); @@ -130,7 +138,7 @@ impl OperatorStateStore { pub fn remove_batches(&self, key: PartitionKey) -> Result<()> { let current_ep = self.current_epoch.load(Ordering::Acquire); let tombstone_mem_size = key.len() + TOMBSTONE_ENTRY_OVERHEAD; - + { let mut tb_guard = self.tombstones.write(); if tb_guard.insert(key.clone(), current_ep).is_none() { @@ -157,13 +165,16 @@ impl OperatorStateStore { pub fn snapshot_epoch(self: &Arc, epoch: u64) -> Result<()> { self.downgrade_active_table(epoch); self.trigger_spill(); - self.current_epoch.store(epoch.saturating_add(1), Ordering::Release); + self.current_epoch + .store(epoch.saturating_add(1), Ordering::Release); Ok(()) } fn downgrade_active_table(&self, epoch: u64) { let mut active_guard = self.active_table.write(); - if active_guard.is_empty() { return; } + if active_guard.is_empty() { + return; + } let old_active = std::mem::take(&mut *active_guard); self.immutable_tables.lock().push_back((epoch, old_active)); } @@ -178,7 +189,9 @@ impl OperatorStateStore { for (table_epoch, table) in self.immutable_tables.lock().iter().rev() { if let Some(del_ep) = deleted_epoch { - if *table_epoch <= del_ep { continue; } + if *table_epoch <= del_ep { + continue; + } } if let Some(batches) = table.get(key) { out.extend(batches.clone()); @@ -186,7 +199,9 @@ impl OperatorStateStore { } let paths: Vec = self.data_files.read().clone(); - if 
paths.is_empty() { return Ok(out); } + if paths.is_empty() { + return Ok(out); + } let pk = key.to_vec(); let merged = tokio::task::spawn_blocking(move || -> Result> { @@ -194,7 +209,9 @@ impl OperatorStateStore { for path in paths { let file_epoch = extract_epoch(&path); if let Some(del_ep) = deleted_epoch { - if file_epoch <= del_ep { continue; } + if file_epoch <= del_ep { + continue; + } } // Native Bloom Filter intercepts empty reads here @@ -207,7 +224,9 @@ impl OperatorStateStore { } } Ok(acc) - }).await.map_err(|_| StateEngineError::Corruption("Tokio task panicked".into()))??; + }) + .await + .map_err(|_| StateEngineError::Corruption("Tokio task panicked".into()))??; out.extend(merged); Ok(out) @@ -223,7 +242,12 @@ impl OperatorStateStore { }; let tombstone_snapshot = self.tombstones.read().clone(); - let job = SpillJob { store: self.clone(), epoch, data, tombstone_snapshot }; + let job = SpillJob { + store: self.clone(), + epoch, + data, + tombstone_snapshot, + }; match self.io_manager.try_send_spill(job) { Ok(()) => {} @@ -238,13 +262,17 @@ impl OperatorStateStore { pub fn trigger_minor_compaction(self: &Arc) { if !self.is_compacting.swap(true, Ordering::SeqCst) { - let _ = self.io_manager.try_send_compact(CompactJob::Minor { store: self.clone() }); + let _ = self.io_manager.try_send_compact(CompactJob::Minor { + store: self.clone(), + }); } } pub fn trigger_major_compaction(self: &Arc) { if !self.is_compacting.swap(true, Ordering::SeqCst) { - let _ = self.io_manager.try_send_compact(CompactJob::Major { store: self.clone() }); + let _ = self.io_manager.try_send_compact(CompactJob::Major { + store: self.clone(), + }); } } @@ -268,7 +296,9 @@ impl OperatorStateStore { if !batches_to_write.is_empty() { let path = self.data_dir.join(Self::generate_data_file_name(epoch)); - if let Err(e) = write_parquet_with_bloom_atomic(&path, &batches_to_write, distinct_keys_count) { + if let Err(e) = + write_parquet_with_bloom_atomic(&path, &batches_to_write, 
distinct_keys_count) + { metrics.inc_io_errors(self.operator_id); let restored = restore_memtable_from_injected_batches(batches_to_write)?; self.immutable_tables.lock().push_front((epoch, restored)); @@ -294,10 +324,16 @@ impl OperatorStateStore { Field::new("deleted_epoch", DataType::UInt64, false), ])); let batch = RecordBatch::try_new( - schema, vec![Arc::new(key_builder.finish()), Arc::new(epoch_builder.finish())] + schema, + vec![ + Arc::new(key_builder.finish()), + Arc::new(epoch_builder.finish()), + ], )?; - let path = self.tombstone_dir.join(Self::generate_tombstone_file_name(epoch)); + let path = self + .tombstone_dir + .join(Self::generate_tombstone_file_name(epoch)); if let Err(e) = write_parquet_with_bloom_atomic(&path, &[batch], tomb_ndv) { metrics.inc_io_errors(self.operator_id); return Err(e); @@ -307,7 +343,7 @@ impl OperatorStateStore { self.mem_ctrl.record_dec(size_to_release); metrics.record_memory_usage(self.operator_id, self.mem_ctrl.usage_bytes()); - + self.is_spilling.store(false, Ordering::SeqCst); self.spill_notify.notify_waiters(); @@ -318,31 +354,45 @@ impl OperatorStateStore { } pub(crate) fn execute_compact_sync( - self: &Arc, + self: &Arc, is_major: bool, - metrics: &Arc + metrics: &Arc, ) -> Result<()> { let result = (|| -> Result<()> { let files_to_merge = { let df = self.data_files.read(); - if df.len() < 2 { return Ok(()); } - if is_major { df.clone() } else { df.iter().take(2).cloned().collect() } + if df.len() < 2 { + return Ok(()); + } + if is_major { + df.clone() + } else { + df.iter().take(2).cloned().collect() + } }; let tombstone_snapshot = self.tombstones.read().clone(); - let compacted_watermark_epoch = files_to_merge.iter().map(|p| extract_epoch(p)).max().unwrap_or(0); - let new_path = self.data_dir.join(Self::generate_data_file_name(compacted_watermark_epoch)); + let compacted_watermark_epoch = files_to_merge + .iter() + .map(|p| extract_epoch(p)) + .max() + .unwrap_or(0); + let new_path = self + .data_dir + 
.join(Self::generate_data_file_name(compacted_watermark_epoch)); let mut all_batches = Vec::new(); let mut estimated_rows = 0; - + for path in &files_to_merge { let file_epoch = extract_epoch(path); let file = File::open(path).map_err(StateEngineError::IoError)?; let mut reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; while let Some(batch) = reader.next() { let b = batch?; - if let Some(filtered) = filter_tombstones_from_batch(&b, &tombstone_snapshot, file_epoch)? { + if let Some(filtered) = + filter_tombstones_from_batch(&b, &tombstone_snapshot, file_epoch)? + { estimated_rows += filtered.num_rows() as u64; all_batches.push(filtered); } @@ -350,7 +400,11 @@ impl OperatorStateStore { } if !all_batches.is_empty() { - if let Err(e) = write_parquet_with_bloom_atomic(&new_path, &all_batches, estimated_rows.max(100)) { + if let Err(e) = write_parquet_with_bloom_atomic( + &new_path, + &all_batches, + estimated_rows.max(100), + ) { metrics.inc_io_errors(self.operator_id); return Err(e); } @@ -362,7 +416,9 @@ impl OperatorStateStore { df.retain(|p| !files_to_merge.contains(p)); } - for path in &files_to_merge { let _ = fs::remove_file(path); } + for path in &files_to_merge { + let _ = fs::remove_file(path); + } // Watermark GC { @@ -373,7 +429,9 @@ impl OperatorStateStore { if *deleted_epoch <= compacted_watermark_epoch { memory_freed += key.len() + TOMBSTONE_ENTRY_OVERHEAD; false - } else { true } + } else { + true + } }); if memory_freed > 0 { @@ -402,14 +460,18 @@ impl OperatorStateStore { pub async fn restore_metadata(&self, safe_epoch: u64) -> Result> { self.active_table.write().clear(); - self.immutable_tables.lock().retain(|(e, _)| *e <= safe_epoch); + self.immutable_tables + .lock() + .retain(|(e, _)| *e <= safe_epoch); let cleanup_future = |files: &mut Vec| { files.retain(|path| { if extract_epoch(path) > safe_epoch { let _ = fs::remove_file(path); false - } else { true } + } else { + true + } }); }; cleanup_future(&mut self.data_files.write()); 
@@ -423,19 +485,31 @@ impl OperatorStateStore { let mut reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; while let Some(batch) = reader.next() { let batch = batch?; - let key_col = batch.column(0).as_any().downcast_ref::().unwrap(); - let ep_col = batch.column(1).as_any().downcast_ref::().unwrap(); + let key_col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let ep_col = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); for i in 0..key_col.len() { let k = key_col.value(i).to_vec(); let e = ep_col.value(i); let current_max = map.get(&k).copied().unwrap_or(0); - if e > current_max { map.insert(k, e); } + if e > current_max { + map.insert(k, e); + } } } } Ok(map) - }).await.map_err(|_| StateEngineError::Corruption("Task Panicked".into()))??; + }) + .await + .map_err(|_| StateEngineError::Corruption("Task Panicked".into()))??; let mut total_tombstone_mem = 0; for key in loaded_tombstones.keys() { @@ -457,19 +531,27 @@ impl OperatorStateStore { while let Some(batch) = reader.next() { let batch = batch?; - let key_col = batch.column(0).as_any().downcast_ref::().unwrap(); + let key_col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); for i in 0..key_col.len() { let k = key_col.value(i).to_vec(); let is_active = match loaded_tombstones.get(&k) { Some(del_ep) => *del_ep < file_epoch, None => true, }; - if is_active { keys.insert(k); } + if is_active { + keys.insert(k); + } } } } Ok(keys) - }).await.map_err(|_| StateEngineError::Corruption("Task Panicked".into()))??; + }) + .await + .map_err(|_| StateEngineError::Corruption("Task Panicked".into()))??; self.current_epoch.store(safe_epoch + 1, Ordering::Release); Ok(active_keys) @@ -478,7 +560,7 @@ impl OperatorStateStore { // ======================================================================== // UUID-based file name generators // ======================================================================== - + fn generate_data_file_name(epoch: u64) -> 
String { format!("data-epoch-{}_uuid-{}.parquet", epoch, Uuid::now_v7()) } @@ -493,7 +575,9 @@ impl OperatorStateStore { // ============================================================================ fn write_parquet_with_bloom_atomic(path: &Path, batches: &[RecordBatch], ndv: u64) -> Result<()> { - if batches.is_empty() { return Ok(()); } + if batches.is_empty() { + return Ok(()); + } let tmp = path.with_extension("tmp"); { let file = File::create(&tmp).map_err(StateEngineError::IoError)?; @@ -503,7 +587,9 @@ fn write_parquet_with_bloom_atomic(path: &Path, batches: &[RecordBatch], ndv: u6 .build(); let mut writer = ArrowWriter::try_new(&file, batches[0].schema(), Some(props))?; - for b in batches { writer.write(b)?; } + for b in batches { + writer.write(b)?; + } writer.close()?; file.sync_all().map_err(StateEngineError::IoError)?; } @@ -512,7 +598,11 @@ fn write_parquet_with_bloom_atomic(path: &Path, batches: &[RecordBatch], ndv: u6 } fn extract_epoch(path: &Path) -> u64 { - let name = path.file_name().unwrap_or_default().to_str().unwrap_or_default(); + let name = path + .file_name() + .unwrap_or_default() + .to_str() + .unwrap_or_default(); if let Some(start) = name.find("-epoch-") { let after = &name[start + 7..]; if let Some(end) = after.find("_uuid-") { @@ -524,9 +614,16 @@ fn extract_epoch(path: &Path) -> u64 { fn inject_partition_key(batch: &RecordBatch, key: &[u8]) -> Result { let mut fields = batch.schema().fields().to_vec(); - fields.push(Arc::new(Field::new(PARTITION_KEY_COL, DataType::Binary, false))); + fields.push(Arc::new(Field::new( + PARTITION_KEY_COL, + DataType::Binary, + false, + ))); let schema = Arc::new(Schema::new(fields)); - let key_array = Arc::new(BinaryArray::from_iter_values(std::iter::repeat_n(key, batch.num_rows()))); + let key_array = Arc::new(BinaryArray::from_iter_values(std::iter::repeat_n( + key, + batch.num_rows(), + ))); let mut cols = batch.columns().to_vec(); cols.push(key_array as Arc); Ok(RecordBatch::try_new(schema, 
cols)?) @@ -537,10 +634,18 @@ fn filter_tombstones_from_batch( tombstones: &TombstoneMap, file_epoch: u64, ) -> Result> { - if tombstones.is_empty() { return Ok(Some(batch.clone())); } - let Ok(idx) = batch.schema().index_of(PARTITION_KEY_COL) else { return Ok(Some(batch.clone())); }; - - let key_col = batch.column(idx).as_any().downcast_ref::().unwrap(); + if tombstones.is_empty() { + return Ok(Some(batch.clone())); + } + let Ok(idx) = batch.schema().index_of(PARTITION_KEY_COL) else { + return Ok(Some(batch.clone())); + }; + + let key_col = batch + .column(idx) + .as_any() + .downcast_ref::() + .unwrap(); let mut mask_builder = BooleanBuilder::with_capacity(batch.num_rows()); let mut has_valid = false; @@ -551,22 +656,39 @@ fn filter_tombstones_from_batch( None => true, }; mask_builder.append_value(keep); - if keep { has_valid = true; } + if keep { + has_valid = true; + } } - if !has_valid { return Ok(None); } + if !has_valid { + return Ok(None); + } let mask = mask_builder.finish(); Ok(Some(arrow::compute::filter_record_batch(batch, &mask)?)) } -fn filter_and_strip_partition_key(batch: &RecordBatch, target_key: &[u8]) -> Result> { - let Ok(idx) = batch.schema().index_of(PARTITION_KEY_COL) else { return Ok(Some(batch.clone())); }; - let key_col = batch.column(idx).as_any().downcast_ref::().unwrap(); +fn filter_and_strip_partition_key( + batch: &RecordBatch, + target_key: &[u8], +) -> Result> { + let Ok(idx) = batch.schema().index_of(PARTITION_KEY_COL) else { + return Ok(Some(batch.clone())); + }; + let key_col = batch + .column(idx) + .as_any() + .downcast_ref::() + .unwrap(); let mut mask_builder = BooleanBuilder::with_capacity(batch.num_rows()); - for i in 0..batch.num_rows() { mask_builder.append_value(key_col.value(i) == target_key); } + for i in 0..batch.num_rows() { + mask_builder.append_value(key_col.value(i) == target_key); + } let mask = mask_builder.finish(); let filtered = arrow::compute::filter_record_batch(batch, &mask)?; - if filtered.num_rows() == 0 
{ return Ok(None); } + if filtered.num_rows() == 0 { + return Ok(None); + } let mut proj: Vec = (0..filtered.num_columns()).collect(); proj.retain(|&i| i != idx); Ok(Some(filtered.project(&proj)?)) @@ -576,7 +698,11 @@ fn restore_memtable_from_injected_batches(batches: Vec) -> Result().unwrap(); + let key_col = batch + .column(idx) + .as_any() + .downcast_ref::() + .unwrap(); let pk = key_col.value(0).to_vec(); let mut proj: Vec = (0..batch.num_columns()).collect(); proj.retain(|&i| i != idx); diff --git a/src/storage/stream_catalog/manager.rs b/src/storage/stream_catalog/manager.rs index 9691e991..471e3cd9 100644 --- a/src/storage/stream_catalog/manager.rs +++ b/src/storage/stream_catalog/manager.rs @@ -17,7 +17,7 @@ use datafusion::common::{Result as DFResult, internal_err, plan_err}; use prost::Message; use protocol::function_stream_graph::FsProgram; use protocol::storage::{self as pb, table_definition}; -use tracing::{info, warn, debug}; +use tracing::{debug, info, warn}; use unicase::UniCase; use crate::sql::common::constants::sql_field; @@ -144,9 +144,7 @@ impl CatalogManager { } /// Returns (table_name, program, checkpoint_interval_ms, latest_checkpoint_epoch). 
- pub fn load_streaming_job_definitions( - &self, - ) -> DFResult> { + pub fn load_streaming_job_definitions(&self) -> DFResult> { let records = self.store.scan_prefix(STREAMING_JOB_KEY_PREFIX)?; let mut out = Vec::with_capacity(records.len()); for (key, payload) in records { From d25fb45be5dbb8dae71c33c8e928af0acf3cbe5c Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Mon, 13 Apr 2026 00:55:17 +0800 Subject: [PATCH 03/26] update --- .../grouping/incremental_aggregate.rs | 328 +++++++++++++++--- .../operators/grouping/updating_cache.rs | 4 + .../windows/session_aggregating_window.rs | 189 +++++++++- .../windows/sliding_aggregating_window.rs | 182 +++++++++- .../windows/tumbling_aggregating_window.rs | 128 ++++++- .../operators/windows/window_function.rs | 207 ++++++----- src/sql/analysis/aggregate_rewriter.rs | 37 +- 7 files changed, 890 insertions(+), 185 deletions(-) diff --git a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs index 625cdee5..fd58c4be 100644 --- a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs +++ b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs @@ -11,15 +11,17 @@ // limitations under the License. 
use crate::sql::common::constants::updating_state_field; -use anyhow::{Result, bail}; -use arrow::compute::max_array; +use anyhow::{Result, anyhow, bail}; +use arrow::compute::{concat_batches, max_array}; use arrow::row::{RowConverter, SortField}; use arrow_array::builder::{ BinaryBuilder, TimestampNanosecondBuilder, UInt32Builder, UInt64Builder, }; use arrow_array::cast::AsArray; use arrow_array::types::UInt64Type; -use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, StructArray}; +use arrow_array::{ + Array, ArrayRef, BinaryArray, BooleanArray, RecordBatch, StructArray, UInt32Array, UInt64Array, +}; use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaBuilder, TimeUnit}; use datafusion::common::{Result as DFResult, ScalarValue}; use datafusion::physical_expr::aggregate::AggregateFunctionExpr; @@ -36,7 +38,7 @@ use std::collections::HashSet; use std::sync::LazyLock; use std::time::{Duration, Instant, SystemTime}; use std::{collections::HashMap, mem, sync::Arc}; -use tracing::{debug, warn}; +use tracing::{debug, info, warn}; // ========================================================================= // ========================================================================= use crate::runtime::streaming::StreamOutput; @@ -44,6 +46,7 @@ use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; use crate::runtime::streaming::operators::{Key, UpdatingCache}; +use crate::runtime::streaming::state::OperatorStateStore; use crate::runtime::util::decode_aggregate; use crate::sql::common::{ CheckpointBarrier, FsSchema, TIMESTAMP_FIELD, UPDATING_META_FIELD, Watermark, to_nanos, @@ -213,10 +216,15 @@ pub struct IncrementalAggregatingFunc { ttl: Duration, key_converter: RowConverter, new_generation: u64, + + state_store: Option>, } static GLOBAL_KEY: LazyLock>> = LazyLock::new(|| Arc::new(Vec::new())); +const KEY_SLIDING_SNAPSHOT: &[u8] = &[0x01]; +const 
KEY_BATCH_SNAPSHOT: &[u8] = &[0x02]; + impl IncrementalAggregatingFunc { fn update_batch( &mut self, @@ -437,40 +445,39 @@ impl IncrementalAggregatingFunc { // ========================================================================= fn checkpoint_sliding(&mut self) -> DFResult>> { - if self.updated_keys.is_empty() { + let keys = self.accumulators.keys(); + if keys.is_empty() { return Ok(None); } let mut states = vec![vec![]; self.sliding_state_schema.schema.fields.len()]; let parser = self.key_converter.parser(); - let mut generation_builder = UInt64Builder::with_capacity(self.updated_keys.len()); - - let mut cols = self - .key_converter - .convert_rows(self.updated_keys.keys().map(|k| { - let (accumulators, generation) = - self.accumulators.get_mut_generation(k.0.as_ref()).unwrap(); - generation_builder.append_value(generation); - - for (state, agg) in accumulators.iter_mut().zip(self.aggregates.iter()) { - let IncrementalState::Sliding { expr, accumulator } = state else { - continue; - }; - let state = accumulator.state().unwrap_or_else(|_| { - let state = accumulator.state().unwrap(); - *accumulator = expr.create_sliding_accumulator().unwrap(); - let states: Vec<_> = - state.iter().map(|s| s.to_array()).try_collect().unwrap(); - accumulator.merge_batch(&states).unwrap(); - state - }); - - for (idx, v) in agg.state_cols.iter().zip(state.into_iter()) { - states[*idx].push(v); - } + let mut generation_builder = UInt64Builder::with_capacity(keys.len()); + + let mut cols = self.key_converter.convert_rows(keys.iter().map(|k| { + let (accumulators, generation) = + self.accumulators.get_mut_generation(k.0.as_ref()).unwrap(); + generation_builder.append_value(generation); + + for (state, agg) in accumulators.iter_mut().zip(self.aggregates.iter()) { + let IncrementalState::Sliding { expr, accumulator } = state else { + continue; + }; + let state = accumulator.state().unwrap_or_else(|_| { + let state = accumulator.state().unwrap(); + *accumulator = 
expr.create_sliding_accumulator().unwrap(); + let states: Vec<_> = + state.iter().map(|s| s.to_array()).try_collect().unwrap(); + accumulator.merge_batch(&states).unwrap(); + state + }); + + for (idx, v) in agg.state_cols.iter().zip(state.into_iter()) { + states[*idx].push(v); } - parser.parse(k.0.as_ref()) - }))?; + } + parser.parse(k.0.as_ref()) + }))?; cols.extend( states @@ -482,7 +489,7 @@ impl IncrementalAggregatingFunc { let generations = generation_builder.finish(); self.new_generation = self .new_generation - .max(max_array::(&generations).unwrap()); + .max(max_array::(&generations).unwrap_or(0)); cols.push(Arc::new(generations)); Ok(Some(cols)) @@ -496,12 +503,22 @@ impl IncrementalAggregatingFunc { { return Ok(None); } - if self.updated_keys.is_empty() { + + let keys = self.accumulators.keys(); + + let mut size = 0; + for k in &keys { + for state in self.accumulators.get_mut(k.0.as_ref()).unwrap().iter_mut() { + if let IncrementalState::Batch { data, .. } = state { + size += data.len(); + } + } + } + if size == 0 { return Ok(None); } - let size = self.updated_keys.len(); - let mut rows = Vec::with_capacity(size); + let mut key_bytes_for_rows = Vec::with_capacity(size); let mut accumulator_builder = UInt32Builder::with_capacity(size); let mut args_row_builder = BinaryBuilder::with_capacity(size, size * 4); let mut count_builder = UInt64Builder::with_capacity(size); @@ -509,10 +526,8 @@ impl IncrementalAggregatingFunc { let mut generation_builder = UInt64Builder::with_capacity(size); let now = to_nanos(SystemTime::now()) as i64; - let parser = self.key_converter.parser(); - for k in self.updated_keys.keys() { - let row = parser.parse(&k.0); + for k in keys { for (i, state) in self .accumulators .get_mut(k.0.as_ref()) @@ -520,29 +535,27 @@ impl IncrementalAggregatingFunc { .iter_mut() .enumerate() { - let IncrementalState::Batch { - data, - changed_values, - .. - } = state - else { + let IncrementalState::Batch { data, .. 
} = state else { continue; }; - for vk in changed_values.iter() { - if let Some(count) = data.get(vk) { - accumulator_builder.append_value(i as u32); - args_row_builder.append_value(&*vk.0); - count_builder.append_value(count.count); - generation_builder.append_value(count.generation); - timestamp_builder.append_value(now); - rows.push(row.to_owned()) - } + for (vk, count_data) in data.iter() { + accumulator_builder.append_value(i as u32); + args_row_builder.append_value(&*vk.0); + count_builder.append_value(count_data.count); + generation_builder.append_value(count_data.generation); + timestamp_builder.append_value(now); + key_bytes_for_rows.push(k.0.clone()); } data.retain(|_, v| v.count > 0); } } + let parser = self.key_converter.parser(); + let rows: Vec<_> = key_bytes_for_rows + .iter() + .map(|kb| parser.parse(kb).to_owned()) + .collect(); let mut cols = self.key_converter.convert_rows(rows.into_iter())?; cols.push(Arc::new(accumulator_builder.finish())); cols.push(Arc::new(args_row_builder.finish())); @@ -552,7 +565,7 @@ impl IncrementalAggregatingFunc { let generations = generation_builder.finish(); self.new_generation = self .new_generation - .max(max_array::(&generations).unwrap()); + .max(max_array::(&generations).unwrap_or(0)); cols.push(Arc::new(generations)); Ok(Some(cols)) @@ -710,7 +723,167 @@ impl Operator for IncrementalAggregatingFunc { } async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Updating Aggregate recovering state from LSM-Tree..." 
+ ); + + let mut sliding_batches = Vec::new(); + let mut batch_batches = Vec::new(); + + for key in active_keys { + if key == KEY_SLIDING_SNAPSHOT { + sliding_batches.extend( + store + .get_batches(&key) + .await + .map_err(|e| anyhow!("{e}"))?, + ); + } else if key == KEY_BATCH_SNAPSHOT { + batch_batches.extend( + store + .get_batches(&key) + .await + .map_err(|e| anyhow!("{e}"))?, + ); + } + } + + let num_keys = self + .input_schema + .routing_keys() + .map(|k| k.len()) + .unwrap_or(0); + let now = Instant::now(); + + // Restore sliding (reversible) accumulator state + if !sliding_batches.is_empty() { + let combined = + concat_batches(&self.sliding_state_schema.schema, &sliding_batches)?; + let key_cols: Vec = combined.columns()[0..num_keys].to_vec(); + let aggregate_states: Vec> = self + .aggregates + .iter() + .map(|agg| { + agg.state_cols + .iter() + .map(|&idx| combined.column(idx).clone()) + .collect() + }) + .collect(); + let gen_col = combined + .column(combined.num_columns() - 1) + .as_any() + .downcast_ref::() + .expect("generation column must be UInt64Array"); + + let rows = self.key_converter.convert_columns(&key_cols)?; + for i in 0..combined.num_rows() { + let key = rows.row(i).as_ref().to_vec(); + let generation = gen_col.value(i); + self.restore_sliding(&key, now, i, &aggregate_states, generation)?; + } + info!( + rows = combined.num_rows(), + "Restored sliding accumulator state." 
+ ); + } + + // Restore batch (non-reversible) detail dictionaries + if !batch_batches.is_empty() { + let combined = + concat_batches(&self.batch_state_schema.schema, &batch_batches)?; + let key_cols: Vec = combined.columns()[0..num_keys].to_vec(); + + let acc_idx_col = combined + .column(num_keys) + .as_any() + .downcast_ref::() + .expect("accumulator index column must be UInt32Array"); + let args_col = combined + .column(num_keys + 1) + .as_any() + .downcast_ref::() + .expect("args_row column must be BinaryArray"); + let count_col = combined + .column(num_keys + 2) + .as_any() + .downcast_ref::() + .expect("count column must be UInt64Array"); + // column num_keys+3 is timestamp, skip + let gen_col = combined + .column(num_keys + 4) + .as_any() + .downcast_ref::() + .expect("generation column must be UInt64Array"); + + let rows = self.key_converter.convert_columns(&key_cols)?; + + for i in 0..combined.num_rows() { + let key = rows.row(i).as_ref().to_vec(); + let acc_idx = acc_idx_col.value(i) as usize; + let args_row = args_col.value(i).to_vec(); + let count = count_col.value(i); + let generation = gen_col.value(i); + + if !self.accumulators.contains_key(&key) { + self.accumulators.insert( + Arc::new(key.clone()), + now, + generation, + self.make_accumulators(), + ); + } + + if let Some(accs) = self.accumulators.get_mut(&key) { + if let Some(IncrementalState::Batch { + data, + changed_values, + .. + }) = accs.get_mut(acc_idx) + { + let vk = Key(Arc::new(args_row.clone())); + data.insert( + vk.clone(), + BatchData { + count, + generation, + }, + ); + changed_values.insert(vk); + } + } + } + info!( + rows = combined.num_rows(), + "Restored batch detail state." + ); + } + + info!( + groups = self.accumulators.keys().len(), + "Updating Aggregate successfully restored active groups." 
+ ); + } + self.initialize(ctx).await?; + self.state_store = Some(store); Ok(()) } @@ -743,9 +916,51 @@ impl Operator for IncrementalAggregatingFunc { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + let store = self + .state_store + .clone() + .expect("State store not initialized"); + + // Tombstone previous epoch snapshots for disk space reclamation + store + .remove_batches(KEY_SLIDING_SNAPSHOT.to_vec()) + .map_err(|e| anyhow!("{e}"))?; + store + .remove_batches(KEY_BATCH_SNAPSHOT.to_vec()) + .map_err(|e| anyhow!("{e}"))?; + + // Full snapshot of sliding (reversible) accumulator state + if let Some(cols) = self.checkpoint_sliding()? { + let batch = + RecordBatch::try_new(self.sliding_state_schema.schema.clone(), cols)?; + store + .put(KEY_SLIDING_SNAPSHOT.to_vec(), batch) + .await + .map_err(|e| anyhow!("{e}"))?; + } + + // Full snapshot of batch (non-reversible) detail state + if let Some(cols) = self.checkpoint_batch()? 
{ + let batch = + RecordBatch::try_new(self.batch_state_schema.schema.clone(), cols)?; + store + .put(KEY_BATCH_SNAPSHOT.to_vec(), batch) + .await + .map_err(|e| anyhow!("{e}"))?; + } + + // Flush to Parquet + store + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + + info!(epoch = barrier.epoch, "Updating Aggregate snapshotted successfully."); + + self.updated_keys.clear(); + Ok(()) } @@ -907,6 +1122,7 @@ impl IncrementalAggregatingConstructor { sliding_state_schema, batch_state_schema, new_generation: 0, + state_store: None, }) } } diff --git a/src/runtime/streaming/operators/grouping/updating_cache.rs b/src/runtime/streaming/operators/grouping/updating_cache.rs index 37f2ba04..34c732fc 100644 --- a/src/runtime/streaming/operators/grouping/updating_cache.rs +++ b/src/runtime/streaming/operators/grouping/updating_cache.rs @@ -64,6 +64,10 @@ impl Iterator for TTLIter<'_, T> { } impl UpdatingCache { + pub fn keys(&self) -> Vec { + self.map.keys().cloned().collect() + } + pub fn with_time_to_idle(ttl: Duration) -> Self { Self { map: HashMap::new(), diff --git a/src/runtime/streaming/operators/windows/session_aggregating_window.rs b/src/runtime/streaming/operators/windows/session_aggregating_window.rs index 4293ea7c..f4a8708e 100644 --- a/src/runtime/streaming/operators/windows/session_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/session_aggregating_window.rs @@ -30,15 +30,17 @@ use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; use futures::StreamExt; use prost::Message; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; +use tracing::info; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; 
use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::converter::Converter; use crate::sql::common::{ CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos, to_nanos, @@ -170,6 +172,7 @@ impl ActiveSession { } } +#[derive(Clone)] struct SessionWindowResult { window_start: SystemTime, window_end: SystemTime, @@ -389,9 +392,39 @@ pub struct SessionWindowOperator { session_states: HashMap, KeySessionState>, pq_watermark_actions: BTreeMap>>, pq_start_times: BTreeMap>>, + + // LSM-Tree state engine and per-routing-key timestamp index + state_store: Option>, + pending_timestamps: HashMap, BTreeSet>, } impl SessionWindowOperator { + // State key: [RoutingKey bytes] + [8-byte big-endian timestamp] + fn build_state_key(routing_key: &[u8], ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(routing_key.len() + 8); + key.extend_from_slice(routing_key); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key + } + + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() >= 8 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[key.len() - 8..]); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None + } + } + + fn extract_routing_key(key: &[u8]) -> Vec { + if key.len() >= 8 { + key[..key.len() - 8].to_vec() + } else { + Vec::new() + } + } + fn filter_batch_by_time( &self, batch: RecordBatch, @@ -430,6 +463,7 @@ impl SessionWindowOperator { &mut self, sorted_batch: RecordBatch, watermark: Option, + is_recovery_replay: bool, ) -> Result<()> { let partition_ranges = if !self.config.input_schema_ref.has_routing_keys() { std::iter::once(0..sorted_batch.num_rows()).collect::>() @@ -470,6 +504,32 @@ impl SessionWindowOperator { .to_vec() }; + // Write-ahead persistence: skip during recovery replay to avoid duplicate writes + if !is_recovery_replay { + let ts_col = key_batch + 
.column(self.config.input_schema_ref.timestamp_index) + .as_any() + .downcast_ref::() + .unwrap(); + let ts_nanos = ts_col.value(0) as u64; + + let state_key = Self::build_state_key(&row_key, ts_nanos); + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + + store + .put(state_key, key_batch.clone()) + .await + .map_err(|e| anyhow!("{e}"))?; + + self.pending_timestamps + .entry(row_key.clone()) + .or_default() + .insert(ts_nanos); + } + let state = self .session_states .entry(row_key.clone()) @@ -529,7 +589,10 @@ impl SessionWindowOperator { Ok(()) } - async fn evaluate_watermark(&mut self, watermark: SystemTime) -> Result> { + async fn evaluate_watermark_with_meta( + &mut self, + watermark: SystemTime, + ) -> Result, Vec)>> { let mut emit_results: Vec<(Vec, Vec)> = Vec::new(); loop { @@ -588,11 +651,7 @@ impl SessionWindowOperator { } } - if emit_results.is_empty() { - return Ok(vec![]); - } - - Ok(vec![self.format_to_arrow(emit_results)?]) + Ok(emit_results) } fn format_to_arrow( @@ -666,10 +725,71 @@ impl Operator for SessionWindowOperator { "SessionWindow" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery & event sourcing: rebuild in-memory sessions from LSM-Tree + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Session Operator recovering active state keys from LSM-Tree..." 
+ ); + + let mut recovered_batches = Vec::new(); + + for key in active_keys { + if let Some(ts) = Self::extract_timestamp(&key) { + let row_key = Self::extract_routing_key(&key); + self.pending_timestamps + .entry(row_key) + .or_default() + .insert(ts); + } + + let batches = store + .get_batches(&key) + .await + .map_err(|e| anyhow!("{e}"))?; + recovered_batches.extend(batches); + } + + // Temporal ordering is critical: replay must preserve watermark/session merge invariants + recovered_batches.sort_by_key(|b| { + b.column(self.config.input_schema_ref.timestamp_index) + .as_any() + .downcast_ref::() + .map(|ts| ts.value(0)) + .unwrap_or(0) + }); + + for batch in recovered_batches { + self.ingest_sorted_batch(batch, None, true).await?; + } + + info!( + pipeline_id = ctx.pipeline_id, + "Session Window Operator successfully replayed events and rebuilt in-memory sessions." + ); + } + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist raw data before in-memory ingestion async fn process_data( &mut self, _input_idx: usize, @@ -685,12 +805,13 @@ impl Operator for SessionWindowOperator { let sorted_batch = self.sort_batch(&filtered_batch)?; - self.ingest_sorted_batch(sorted_batch, watermark_time) + self.ingest_sorted_batch(sorted_batch, watermark_time, false) .await?; Ok(vec![]) } + // Watermark-driven session closure with precise LSM-Tree garbage collection async fn process_watermark( &mut self, watermark: Watermark, @@ -700,18 +821,54 @@ impl Operator for SessionWindowOperator { return Ok(vec![]); }; - let output_batches = self.evaluate_watermark(current_time).await?; - Ok(output_batches - .into_iter() - .map(StreamOutput::Forward) - .collect()) + let completed_sessions = self.evaluate_watermark_with_meta(current_time).await?; + if completed_sessions.is_empty() { + return Ok(vec![]); + } + + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + + // GC: tombstone expired raw data covered by closed sessions + for 
(row_key, session_results) in &completed_sessions { + if let Some(ts_set) = self.pending_timestamps.get_mut(row_key) { + for session_res in session_results { + let start_nanos = to_nanos(session_res.window_start) as u64; + let end_nanos = + to_nanos(session_res.window_end - self.config.gap) as u64; + + let expired_ts: Vec = + ts_set.range(start_nanos..=end_nanos).copied().collect(); + + for ts in expired_ts { + let state_key = Self::build_state_key(row_key, ts); + store + .remove_batches(state_key) + .map_err(|e| anyhow!("{e}"))?; + ts_set.remove(&ts); + } + } + } + } + + let output_batch = self.format_to_arrow(completed_sessions)?; + Ok(vec![StreamOutput::Forward(output_batch)]) } async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + + info!(epoch = barrier.epoch, "Session Window Operator snapshotted state."); Ok(()) } @@ -797,6 +954,8 @@ impl SessionAggregatingWindowConstructor { pq_start_times: BTreeMap::new(), pq_watermark_actions: BTreeMap::new(), row_converter, + state_store: None, + pending_timestamps: HashMap::new(), }) } } diff --git a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs index 73ba4dc9..294a035f 100644 --- a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs @@ -11,7 +11,7 @@ // limitations under the License. 
use anyhow::{Result, anyhow, bail}; -use arrow::compute::{partition, sort_to_indices, take}; +use arrow::compute::{concat_batches, partition, sort_to_indices, take}; use arrow_array::{Array, PrimitiveArray, RecordBatch, types::TimestampNanosecondType}; use arrow_schema::SchemaRef; use datafusion::common::ScalarValue; @@ -27,20 +27,49 @@ use datafusion_proto::{ }; use futures::StreamExt; use prost::Message; -use std::collections::{BTreeMap, VecDeque}; +use std::collections::{BTreeMap, BTreeSet, VecDeque}; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; +use tracing::info; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark, from_nanos, to_nanos}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; use protocol::function_stream_graph::SlidingWindowAggregateOperator; // ============================================================================ +// Dual-layer state key: [StateType(1B)] + [Timestamp(8B BE)] +// STATE_TYPE_RAW = 0 (raw input data, pending partial aggregation) +// STATE_TYPE_PARTIAL = 1 (pre-aggregated pane results) +// ============================================================================ + +const STATE_TYPE_RAW: u8 = 0; +const STATE_TYPE_PARTIAL: u8 = 1; + +fn build_state_key(state_type: u8, ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(9); + key.push(state_type); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key +} + +fn parse_state_key(key: &[u8]) -> Option<(u8, u64)> { + if key.len() == 9 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[1..9]); + Some((key[0], 
u64::from_be_bytes(ts_bytes))) + } else { + None + } +} + +// ============================================================================ +// RecordBatchTier & TieredRecordBatchHolder // ============================================================================ #[derive(Default, Debug)] @@ -263,6 +292,11 @@ pub struct SlidingWindowOperator { active_bins: BTreeMap, tiered_record_batches: TieredRecordBatchHolder, + + // LSM-Tree state engine with dual-layer index + state_store: Option>, + pending_raw_bins: BTreeSet, + pending_partial_bins: BTreeSet, } impl SlidingWindowOperator { @@ -309,10 +343,80 @@ impl Operator for SlidingWindowOperator { "SlidingWindow" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery: restore dual-layer state (partial panes + raw active bins) + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + let mut raw_recovery_batches = Vec::new(); + + for key in active_keys { + if let Some((state_type, ts_nanos)) = parse_state_key(&key) { + let batches = store + .get_batches(&key) + .await + .map_err(|e| anyhow!("{e}"))?; + if batches.is_empty() { + continue; + } + + if state_type == STATE_TYPE_PARTIAL { + let bin_start = from_nanos(ts_nanos as u128); + for b in batches { + self.tiered_record_batches.insert(b, bin_start)?; + } + self.pending_partial_bins.insert(ts_nanos); + } else if state_type == STATE_TYPE_RAW { + let schema = batches[0].schema(); + let combined = concat_batches(&schema, &batches)?; + raw_recovery_batches.push((ts_nanos, combined)); + } + } + } + + // Temporal ordering guarantees correct DataFusion session 
replay + raw_recovery_batches.sort_by_key(|(ts, _)| *ts); + + for (ts_nanos, batch) in raw_recovery_batches { + let bin_start = from_nanos(ts_nanos as u128); + let slot = self.active_bins.entry(bin_start).or_default(); + Self::ensure_bin_running( + slot, + self.partial_aggregation_plan.clone(), + &self.receiver_hook, + )?; + + slot.sender + .as_ref() + .unwrap() + .send(batch) + .map_err(|e| anyhow!("{e}"))?; + self.pending_raw_bins.insert(ts_nanos); + } + + info!( + pipeline_id = ctx.pipeline_id, + partial_bins = self.pending_partial_bins.len(), + raw_bins = self.pending_raw_bins.len(), + "Sliding Window Operator recovered state." + ); + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist raw data (Type 0) before in-memory computation async fn process_data( &mut self, _input_idx: usize, @@ -340,6 +444,7 @@ impl Operator for SlidingWindowOperator { let partition_ranges = partition(std::slice::from_ref(&sorted_bins))?.ranges(); let watermark = ctx.current_watermark(); + let store = self.state_store.clone().expect("State store not initialized"); for range in partition_ranges { let bin_start = from_nanos(typed_bin.value(range.start) as u128); @@ -351,8 +456,16 @@ impl Operator for SlidingWindowOperator { } let bin_batch = sorted.slice(range.start, range.end - range.start); - let slot = self.active_bins.entry(bin_start).or_default(); + let bin_start_nanos = to_nanos(bin_start) as u64; + + let key = build_state_key(STATE_TYPE_RAW, bin_start_nanos); + store + .put(key, bin_batch.clone()) + .await + .map_err(|e| anyhow!("{e}"))?; + self.pending_raw_bins.insert(bin_start_nanos); + let slot = self.active_bins.entry(bin_start).or_default(); Self::ensure_bin_running( slot, self.partial_aggregation_plan.clone(), @@ -371,6 +484,7 @@ impl Operator for SlidingWindowOperator { Ok(vec![]) } + // State morphing (Type 0 → Type 1) and dual-layer GC async fn process_watermark( &mut self, watermark: Watermark, @@ -380,6 +494,7 @@ impl Operator for SlidingWindowOperator 
{ return Ok(vec![]); }; let watermark_bin = self.bin_start(current_time); + let store = self.state_store.clone().expect("State store not initialized"); let mut final_outputs = Vec::new(); @@ -398,12 +513,36 @@ impl Operator for SlidingWindowOperator { .remove(&bin_start) .ok_or_else(|| anyhow!("missing active bin"))?; let bin_end = bin_start + self.slide; + let bin_start_nanos = to_nanos(bin_start) as u64; + // Phase 1: drain partial aggregation from DataFusion bin.close_and_drain().await?; - for b in bin.finished_batches { - self.tiered_record_batches.insert(b, bin_start)?; + + // Phase 2: state morphing — persist partial result (Type 1), feed tiered holder + if !bin.finished_batches.is_empty() { + let schema = bin.finished_batches[0].schema(); + let combined_partial = concat_batches(&schema, &bin.finished_batches)?; + + let p_key = build_state_key(STATE_TYPE_PARTIAL, bin_start_nanos); + store + .put(p_key, combined_partial) + .await + .map_err(|e| anyhow!("{e}"))?; + self.pending_partial_bins.insert(bin_start_nanos); + + for b in bin.finished_batches { + self.tiered_record_batches.insert(b, bin_start)?; + } } + // Phase 3: tombstone raw data (Type 0) — no longer needed after partial is saved + let r_key = build_state_key(STATE_TYPE_RAW, bin_start_nanos); + store + .remove_batches(r_key) + .map_err(|e| anyhow!("{e}"))?; + self.pending_raw_bins.remove(&bin_start_nanos); + + // Phase 4: compute final sliding window result let interval_start = bin_end - self.width; let interval_end = bin_end; @@ -436,8 +575,25 @@ impl Operator for SlidingWindowOperator { final_outputs.push(StreamOutput::Forward(batch?)); } - self.tiered_record_batches - .delete_before(bin_end + self.slide - self.width)?; + // Phase 5: GC expired partial bins (Type 1) that fall outside the window + let cutoff_time = bin_end + self.slide - self.width; + self.tiered_record_batches.delete_before(cutoff_time)?; + + let cutoff_nanos = to_nanos(cutoff_time) as u64; + let expired_partials: Vec = self + 
.pending_partial_bins + .iter() + .take_while(|&&ts| ts < cutoff_nanos) + .copied() + .collect(); + + for ts in expired_partials { + let p_key = build_state_key(STATE_TYPE_PARTIAL, ts); + store + .remove_batches(p_key) + .map_err(|e| anyhow!("{e}"))?; + self.pending_partial_bins.remove(&ts); + } } Ok(final_outputs) @@ -445,9 +601,14 @@ impl Operator for SlidingWindowOperator { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } @@ -531,6 +692,9 @@ impl SlidingAggregatingWindowConstructor { final_batches_passer, active_bins: BTreeMap::new(), tiered_record_batches: TieredRecordBatchHolder::new(vec![slide])?, + state_store: None, + pending_raw_bins: BTreeSet::new(), + pending_partial_bins: BTreeSet::new(), }) } } diff --git a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs index de576bf0..0a6f0af9 100644 --- a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs @@ -27,17 +27,18 @@ use datafusion_proto::{ }; use futures::StreamExt; use prost::Message; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use std::mem; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tracing::warn; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use 
crate::sql::common::time_utils::print_time; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark, from_nanos, to_nanos}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; @@ -94,9 +95,28 @@ pub struct TumblingWindowOperator { final_batches_passer: Arc>>, active_bins: BTreeMap, + + // LSM-Tree state engine and pending window timestamp index + state_store: Option>, + pending_bins: BTreeSet, } impl TumblingWindowOperator { + // State key: 8-byte big-endian bin_start_nanos + fn build_state_key(ts_nanos: u64) -> Vec { + ts_nanos.to_be_bytes().to_vec() + } + + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 8 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(key); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None + } + } + fn bin_start(&self, timestamp: SystemTime) -> SystemTime { if self.width == Duration::ZERO { return timestamp; @@ -141,10 +161,70 @@ impl Operator for TumblingWindowOperator { "TumblingWindow" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery: replay raw data from LSM-Tree into DataFusion sessions + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Tumbling Window Operator recovering active windows from LSM-Tree..." 
+ ); + + for key in active_keys { + if let Some(ts_nanos) = Self::extract_timestamp(&key) { + let bin_start = from_nanos(ts_nanos as u128); + + let batches = store + .get_batches(&key) + .await + .map_err(|e| anyhow!("{e}"))?; + if batches.is_empty() { + continue; + } + + let slot = self.active_bins.entry(bin_start).or_default(); + Self::ensure_bin_running( + slot, + self.partial_aggregation_plan.clone(), + &self.receiver_hook, + )?; + + let sender = slot.sender.as_ref().unwrap(); + for batch in batches { + sender + .send(batch) + .map_err(|e| anyhow!("recovery channel send: {e}"))?; + } + + self.pending_bins.insert(ts_nanos); + } + } + + info!( + pipeline_id = ctx.pipeline_id, + "Tumbling Window Operator successfully replayed events and rebuilt in-memory state." + ); + } + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist raw data before in-memory computation async fn process_data( &mut self, _input_idx: usize, @@ -171,6 +251,11 @@ impl Operator for TumblingWindowOperator { .ok_or_else(|| anyhow!("binning function must produce TimestampNanosecond"))?; let partition_ranges = partition(std::slice::from_ref(&sorted_bins))?.ranges(); + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + for range in partition_ranges { let bin_start = from_nanos(typed_bin.value(range.start) as u128); @@ -186,8 +271,16 @@ impl Operator for TumblingWindowOperator { } let bin_batch = sorted.slice(range.start, range.end - range.start); - let slot = self.active_bins.entry(bin_start).or_default(); + let bin_start_nanos = to_nanos(bin_start) as u64; + + let state_key = Self::build_state_key(bin_start_nanos); + store + .put(state_key, bin_batch.clone()) + .await + .map_err(|e| anyhow!("{e}"))?; + self.pending_bins.insert(bin_start_nanos); + let slot = self.active_bins.entry(bin_start).or_default(); Self::ensure_bin_running( slot, self.partial_aggregation_plan.clone(), @@ -206,6 +299,7 @@ impl Operator for TumblingWindowOperator { 
Ok(vec![]) } + // Watermark-driven window closure with LSM-Tree GC async fn process_watermark( &mut self, watermark: Watermark, @@ -214,6 +308,10 @@ impl Operator for TumblingWindowOperator { let Watermark::EventTime(current_time) = watermark else { return Ok(vec![]); }; + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); let mut final_outputs = Vec::new(); @@ -227,10 +325,8 @@ impl Operator for TumblingWindowOperator { } for bin_start in expired_bins { - let mut bin = self - .active_bins - .remove(&bin_start) - .ok_or_else(|| anyhow!("missing tumbling bin"))?; + let mut bin = self.active_bins.remove(&bin_start).unwrap(); + let bin_start_nanos = to_nanos(bin_start) as u64; bin.close_and_drain().await?; let partial_batches = mem::take(&mut bin.finished_batches); @@ -271,6 +367,13 @@ impl Operator for TumblingWindowOperator { final_outputs.push(StreamOutput::Forward(batch?)); } } + + // Tombstone the raw data — window is fully closed + let state_key = Self::build_state_key(bin_start_nanos); + store + .remove_batches(state_key) + .map_err(|e| anyhow!("{e}"))?; + self.pending_bins.remove(&bin_start_nanos); } Ok(final_outputs) @@ -278,9 +381,14 @@ impl Operator for TumblingWindowOperator { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } @@ -367,6 +475,8 @@ impl TumblingAggregateWindowConstructor { receiver_hook, final_batches_passer, active_bins: BTreeMap::new(), + state_store: None, + pending_bins: BTreeSet::new(), }) } } diff --git a/src/runtime/streaming/operators/windows/window_function.rs b/src/runtime/streaming/operators/windows/window_function.rs index 5e340fec..66349ea2 100644 --- a/src/runtime/streaming/operators/windows/window_function.rs +++ 
b/src/runtime/streaming/operators/windows/window_function.rs @@ -13,7 +13,6 @@ use anyhow::{Result, anyhow}; use arrow::compute::{max, min}; use arrow_array::RecordBatch; -use datafusion::execution::SendableRecordBatchStream; use datafusion::execution::context::SessionContext; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::ExecutionPlan; @@ -21,57 +20,26 @@ use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; use futures::StreamExt; use prost::Message; -use std::collections::BTreeMap; +use std::collections::BTreeSet; use std::sync::{Arc, RwLock}; use std::time::SystemTime; -use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tracing::warn; +use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel}; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::time_utils::print_time; -use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos}; +use crate::sql::common::{ + CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos, to_nanos, +}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; // ============================================================================ -// ============================================================================ - -struct ActiveWindowExec { - sender: Option>, - result_stream: Option, -} - -impl ActiveWindowExec { - fn new( - plan: Arc, - hook: &Arc>>>, - ) -> Result { - let (tx, rx) = unbounded_channel(); - *hook.write().unwrap() = Some(rx); - plan.reset()?; - let result_stream = plan.execute(0, SessionContext::new().task_ctx())?; - Ok(Self { - sender: 
Some(tx), - result_stream: Some(result_stream), - }) - } - - async fn close_and_drain(&mut self) -> Result> { - self.sender.take(); - let mut results = Vec::new(); - if let Some(mut stream) = self.result_stream.take() { - while let Some(batch) = stream.next().await { - results.push(batch?); - } - } - Ok(results) - } -} - -// ============================================================================ +// WindowFunctionOperator: LSM-Tree backed lazy-compute model // ============================================================================ pub struct WindowFunctionOperator { @@ -79,10 +47,28 @@ pub struct WindowFunctionOperator { input_schema_unkeyed: FsSchemaRef, window_exec_plan: Arc, receiver_hook: Arc>>>, - active_execs: BTreeMap, + + // LSM-Tree state engine and lightweight timestamp index + state_store: Option>, + pending_timestamps: BTreeSet, } impl WindowFunctionOperator { + // State key: 8-byte big-endian timestamp (nanos) + fn build_state_key(ts_nanos: u64) -> Vec { + ts_nanos.to_be_bytes().to_vec() + } + + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 8 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(key); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None + } + } + fn filter_and_split_batches( &self, batch: RecordBatch, @@ -137,18 +123,6 @@ impl WindowFunctionOperator { } Ok(batches) } - - fn get_or_create_exec(&mut self, timestamp: SystemTime) -> Result<&mut ActiveWindowExec> { - use std::collections::btree_map::Entry; - match self.active_execs.entry(timestamp) { - Entry::Vacant(v) => { - let new_exec = - ActiveWindowExec::new(self.window_exec_plan.clone(), &self.receiver_hook)?; - Ok(v.insert(new_exec)) - } - Entry::Occupied(o) => Ok(o.into_mut()), - } - } } #[async_trait] @@ -157,10 +131,47 @@ impl Operator for WindowFunctionOperator { "WindowFunction" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery: restore the lightweight timestamp index from LSM-Tree. 
+ // Data stays on disk until process_watermark triggers on-demand compute. + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Window Function Operator recovering active timestamps from LSM-Tree..." + ); + + for key in active_keys { + if let Some(ts_nanos) = Self::extract_timestamp(&key) { + self.pending_timestamps.insert(ts_nanos); + } + } + + info!( + pipeline_id = ctx.pipeline_id, + "Window Function Operator successfully rebuilt in-memory indices." + ); + } + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist data into LSM-Tree, defer computation to watermark async fn process_data( &mut self, _input_idx: usize, @@ -169,19 +180,27 @@ impl Operator for WindowFunctionOperator { ) -> Result> { let current_watermark = ctx.current_watermark(); let split_batches = self.filter_and_split_batches(batch, current_watermark)?; + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); for (sub_batch, timestamp) in split_batches { - let exec = self.get_or_create_exec(timestamp)?; - exec.sender - .as_ref() - .ok_or_else(|| anyhow!("window exec sender missing"))? 
- .send(sub_batch) - .map_err(|e| anyhow!("route batch to plan: {e}"))?; + let ts_nanos = to_nanos(timestamp) as u64; + let key = Self::build_state_key(ts_nanos); + + store + .put(key, sub_batch) + .await + .map_err(|e| anyhow!("{e}"))?; + + self.pending_timestamps.insert(ts_nanos); } Ok(vec![]) } + // On-demand compute & GC: pull data from LSM-Tree, run DataFusion, tombstone async fn process_watermark( &mut self, watermark: Watermark, @@ -190,27 +209,53 @@ impl Operator for WindowFunctionOperator { let Watermark::EventTime(current_time) = watermark else { return Ok(vec![]); }; + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + let current_nanos = to_nanos(current_time) as u64; + + let expired_ts: Vec = self + .pending_timestamps + .iter() + .take_while(|&&ts| ts < current_nanos) + .copied() + .collect(); let mut final_outputs = Vec::new(); - let mut expired_timestamps = Vec::new(); - for &k in self.active_execs.keys() { - if k < current_time { - expired_timestamps.push(k); - } else { - break; - } - } + for ts in expired_ts { + let key = Self::build_state_key(ts); + + let batches = store + .get_batches(&key) + .await + .map_err(|e| anyhow!("{e}"))?; - for ts in expired_timestamps { - let mut exec = self - .active_execs - .remove(&ts) - .ok_or_else(|| anyhow!("missing window exec"))?; - let result_batches = exec.close_and_drain().await?; - for batch in result_batches { - final_outputs.push(StreamOutput::Forward(batch)); + if !batches.is_empty() { + let (tx, rx) = unbounded_channel(); + *self.receiver_hook.write().unwrap() = Some(rx); + + self.window_exec_plan.reset()?; + let mut stream = self + .window_exec_plan + .execute(0, SessionContext::new().task_ctx())?; + + for batch in batches { + tx.send(batch) + .map_err(|e| anyhow!("Failed to send batch to execution plan: {e}"))?; + } + drop(tx); + + while let Some(res) = stream.next().await { + final_outputs.push(StreamOutput::Forward(res?)); + } } + + store + .remove_batches(key) 
+ .map_err(|e| anyhow!("{e}"))?; + self.pending_timestamps.remove(&ts); } Ok(final_outputs) @@ -218,9 +263,14 @@ impl Operator for WindowFunctionOperator { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } @@ -275,7 +325,8 @@ impl WindowFunctionConstructor { input_schema_unkeyed, window_exec_plan, receiver_hook, - active_execs: BTreeMap::new(), + state_store: None, + pending_timestamps: BTreeSet::new(), }) } } diff --git a/src/sql/analysis/aggregate_rewriter.rs b/src/sql/analysis/aggregate_rewriter.rs index d7be0db8..ddcb0294 100644 --- a/src/sql/analysis/aggregate_rewriter.rs +++ b/src/sql/analysis/aggregate_rewriter.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use crate::sql::analysis::streaming_window_analzer::StreamingWindowAnalzer; use crate::sql::logical_node::aggregate::StreamWindowAggregateNode; use crate::sql::logical_node::key_calculation::{KeyExtractionNode, KeyExtractionStrategy}; +use crate::sql::logical_node::updating_aggregate::ContinuousAggregateNode; use crate::sql::schema::StreamSchemaProvider; use crate::sql::types::{ QualifiedField, TIMESTAMP_FIELD, WindowBehavior, WindowType, build_df_schema_with_metadata, @@ -70,10 +71,10 @@ impl TreeNodeRewriter for AggregateRewriter<'_> { }) .collect(); - // 3. Dispatch to Updating Aggregate if no windowing is detected. + // 3. Dispatch to ContinuousAggregateNode (UpdatingAggregate) if no windowing is detected. 
let input_window = StreamingWindowAnalzer::get_window(&agg.input)?; if window_exprs.is_empty() && input_window.is_none() { - return self.rewrite_as_updating_aggregate( + return self.rewrite_as_continuous_updating_aggregate( agg.input, key_fields, agg.group_expr, @@ -174,9 +175,9 @@ impl<'a> AggregateRewriter<'a> { })) } - /// [Strategy] Rewrites standard GROUP BY into a non-windowed updating aggregate. + /// [Strategy] Rewrites standard GROUP BY into a ContinuousAggregateNode with retraction semantics. /// Injected max(_timestamp) ensures the streaming pulse (Watermark) continues to propagate. - fn rewrite_as_updating_aggregate( + fn rewrite_as_continuous_updating_aggregate( &self, input: Arc, key_fields: Vec, @@ -184,6 +185,7 @@ impl<'a> AggregateRewriter<'a> { mut aggr_expr: Vec, schema: Arc, ) -> Result> { + let key_count = key_fields.len(); let keyed_input = self.build_keyed_input(input, &group_expr, &key_fields)?; // Ensure the updating stream maintains time awareness. @@ -207,14 +209,23 @@ impl<'a> AggregateRewriter<'a> { schema.metadata().clone(), )?); - let aggregate = Aggregate::try_new_with_schema( + let base_aggregate = Aggregate::try_new_with_schema( Arc::new(keyed_input), group_expr, aggr_expr, output_schema, )?; - Ok(Transformed::yes(LogicalPlan::Aggregate(aggregate))) + let continuous_node = ContinuousAggregateNode::try_new( + LogicalPlan::Aggregate(base_aggregate), + (0..key_count).collect(), + None, + self.schema_provider.planning_options.ttl, + )?; + + Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(continuous_node), + }))) } /// [Strategy] Reconciles window definitions between the input stream and the current GROUP BY. @@ -232,24 +243,16 @@ impl<'a> AggregateRewriter<'a> { let has_group_window = !window_expr_info.is_empty(); match (input_window, has_group_window) { - // Re-aggregation or subquery with an existing window. 
(Some(i_win), true) => { let (idx, g_win) = window_expr_info.pop().unwrap(); if i_win != g_win { - return plan_err!( - "Inconsistent windowing: input is {:?}, but group by is {:?}", - i_win, - g_win - ); + return plan_err!("Inconsistent windowing detected"); } if let Some(field) = visitor.fields.iter().next() { group_expr[idx] = Expr::Column(field.qualified_column()); Ok(WindowBehavior::InData) } else { - if matches!(i_win, WindowType::Session { .. }) { - return plan_err!("Nested session windows are not supported"); - } group_expr.remove(idx); Ok(WindowBehavior::FromOperator { window: i_win, @@ -259,7 +262,6 @@ impl<'a> AggregateRewriter<'a> { }) } } - // First-time windowing defined in this aggregate. (None, true) => { let (idx, g_win) = window_expr_info.pop().unwrap(); group_expr.remove(idx); @@ -270,9 +272,8 @@ impl<'a> AggregateRewriter<'a> { is_nested: false, }) } - // Passthrough: input is already windowed, no new window in group by. (Some(_), false) => Ok(WindowBehavior::InData), - _ => unreachable!("Dispatched to non-windowed path previously"), + _ => unreachable!("Handled by updating path"), } } } From 70501bd9049f5e0af6a425505b67aa74be009157 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Mon, 13 Apr 2026 01:00:00 +0800 Subject: [PATCH 04/26] update --- .../grouping/incremental_aggregate.rs | 49 ++++++------------- .../windows/session_aggregating_window.rs | 13 +++-- .../windows/sliding_aggregating_window.rs | 23 +++++---- .../windows/tumbling_aggregating_window.rs | 5 +- .../operators/windows/window_function.rs | 9 +--- 5 files changed, 35 insertions(+), 64 deletions(-) diff --git a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs index fd58c4be..2299bac4 100644 --- a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs +++ b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs @@ -466,8 +466,7 @@ impl IncrementalAggregatingFunc { let state = 
accumulator.state().unwrap_or_else(|_| { let state = accumulator.state().unwrap(); *accumulator = expr.create_sliding_accumulator().unwrap(); - let states: Vec<_> = - state.iter().map(|s| s.to_array()).try_collect().unwrap(); + let states: Vec<_> = state.iter().map(|s| s.to_array()).try_collect().unwrap(); accumulator.merge_batch(&states).unwrap(); state }); @@ -749,19 +748,11 @@ impl Operator for IncrementalAggregatingFunc { for key in active_keys { if key == KEY_SLIDING_SNAPSHOT { - sliding_batches.extend( - store - .get_batches(&key) - .await - .map_err(|e| anyhow!("{e}"))?, - ); + sliding_batches + .extend(store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?); } else if key == KEY_BATCH_SNAPSHOT { - batch_batches.extend( - store - .get_batches(&key) - .await - .map_err(|e| anyhow!("{e}"))?, - ); + batch_batches + .extend(store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?); } } @@ -774,8 +765,7 @@ impl Operator for IncrementalAggregatingFunc { // Restore sliding (reversible) accumulator state if !sliding_batches.is_empty() { - let combined = - concat_batches(&self.sliding_state_schema.schema, &sliding_batches)?; + let combined = concat_batches(&self.sliding_state_schema.schema, &sliding_batches)?; let key_cols: Vec = combined.columns()[0..num_keys].to_vec(); let aggregate_states: Vec> = self .aggregates @@ -807,8 +797,7 @@ impl Operator for IncrementalAggregatingFunc { // Restore batch (non-reversible) detail dictionaries if !batch_batches.is_empty() { - let combined = - concat_batches(&self.batch_state_schema.schema, &batch_batches)?; + let combined = concat_batches(&self.batch_state_schema.schema, &batch_batches)?; let key_cols: Vec = combined.columns()[0..num_keys].to_vec(); let acc_idx_col = combined @@ -859,21 +848,12 @@ impl Operator for IncrementalAggregatingFunc { }) = accs.get_mut(acc_idx) { let vk = Key(Arc::new(args_row.clone())); - data.insert( - vk.clone(), - BatchData { - count, - generation, - }, - ); + data.insert(vk.clone(), 
BatchData { count, generation }); changed_values.insert(vk); } } } - info!( - rows = combined.num_rows(), - "Restored batch detail state." - ); + info!(rows = combined.num_rows(), "Restored batch detail state."); } info!( @@ -934,8 +914,7 @@ impl Operator for IncrementalAggregatingFunc { // Full snapshot of sliding (reversible) accumulator state if let Some(cols) = self.checkpoint_sliding()? { - let batch = - RecordBatch::try_new(self.sliding_state_schema.schema.clone(), cols)?; + let batch = RecordBatch::try_new(self.sliding_state_schema.schema.clone(), cols)?; store .put(KEY_SLIDING_SNAPSHOT.to_vec(), batch) .await @@ -944,8 +923,7 @@ impl Operator for IncrementalAggregatingFunc { // Full snapshot of batch (non-reversible) detail state if let Some(cols) = self.checkpoint_batch()? { - let batch = - RecordBatch::try_new(self.batch_state_schema.schema.clone(), cols)?; + let batch = RecordBatch::try_new(self.batch_state_schema.schema.clone(), cols)?; store .put(KEY_BATCH_SNAPSHOT.to_vec(), batch) .await @@ -957,7 +935,10 @@ impl Operator for IncrementalAggregatingFunc { .snapshot_epoch(barrier.epoch as u64) .map_err(|e| anyhow!("Snapshot failed: {e}"))?; - info!(epoch = barrier.epoch, "Updating Aggregate snapshotted successfully."); + info!( + epoch = barrier.epoch, + "Updating Aggregate snapshotted successfully." 
+ ); self.updated_keys.clear(); diff --git a/src/runtime/streaming/operators/windows/session_aggregating_window.rs b/src/runtime/streaming/operators/windows/session_aggregating_window.rs index f4a8708e..15075964 100644 --- a/src/runtime/streaming/operators/windows/session_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/session_aggregating_window.rs @@ -759,10 +759,7 @@ impl Operator for SessionWindowOperator { .insert(ts); } - let batches = store - .get_batches(&key) - .await - .map_err(|e| anyhow!("{e}"))?; + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; recovered_batches.extend(batches); } @@ -836,8 +833,7 @@ impl Operator for SessionWindowOperator { if let Some(ts_set) = self.pending_timestamps.get_mut(row_key) { for session_res in session_results { let start_nanos = to_nanos(session_res.window_start) as u64; - let end_nanos = - to_nanos(session_res.window_end - self.config.gap) as u64; + let end_nanos = to_nanos(session_res.window_end - self.config.gap) as u64; let expired_ts: Vec = ts_set.range(start_nanos..=end_nanos).copied().collect(); @@ -868,7 +864,10 @@ impl Operator for SessionWindowOperator { .snapshot_epoch(barrier.epoch as u64) .map_err(|e| anyhow!("Snapshot failed: {e}"))?; - info!(epoch = barrier.epoch, "Session Window Operator snapshotted state."); + info!( + epoch = barrier.epoch, + "Session Window Operator snapshotted state." 
+ ); Ok(()) } diff --git a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs index 294a035f..538e0dad 100644 --- a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs @@ -363,10 +363,7 @@ impl Operator for SlidingWindowOperator { for key in active_keys { if let Some((state_type, ts_nanos)) = parse_state_key(&key) { - let batches = store - .get_batches(&key) - .await - .map_err(|e| anyhow!("{e}"))?; + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; if batches.is_empty() { continue; } @@ -444,7 +441,10 @@ impl Operator for SlidingWindowOperator { let partition_ranges = partition(std::slice::from_ref(&sorted_bins))?.ranges(); let watermark = ctx.current_watermark(); - let store = self.state_store.clone().expect("State store not initialized"); + let store = self + .state_store + .clone() + .expect("State store not initialized"); for range in partition_ranges { let bin_start = from_nanos(typed_bin.value(range.start) as u128); @@ -494,7 +494,10 @@ impl Operator for SlidingWindowOperator { return Ok(vec![]); }; let watermark_bin = self.bin_start(current_time); - let store = self.state_store.clone().expect("State store not initialized"); + let store = self + .state_store + .clone() + .expect("State store not initialized"); let mut final_outputs = Vec::new(); @@ -537,9 +540,7 @@ impl Operator for SlidingWindowOperator { // Phase 3: tombstone raw data (Type 0) — no longer needed after partial is saved let r_key = build_state_key(STATE_TYPE_RAW, bin_start_nanos); - store - .remove_batches(r_key) - .map_err(|e| anyhow!("{e}"))?; + store.remove_batches(r_key).map_err(|e| anyhow!("{e}"))?; self.pending_raw_bins.remove(&bin_start_nanos); // Phase 4: compute final sliding window result @@ -589,9 +590,7 @@ impl Operator for SlidingWindowOperator { for ts in expired_partials 
{ let p_key = build_state_key(STATE_TYPE_PARTIAL, ts); - store - .remove_batches(p_key) - .map_err(|e| anyhow!("{e}"))?; + store.remove_batches(p_key).map_err(|e| anyhow!("{e}"))?; self.pending_partial_bins.remove(&ts); } } diff --git a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs index 0a6f0af9..7bf3268d 100644 --- a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs @@ -188,10 +188,7 @@ impl Operator for TumblingWindowOperator { if let Some(ts_nanos) = Self::extract_timestamp(&key) { let bin_start = from_nanos(ts_nanos as u128); - let batches = store - .get_batches(&key) - .await - .map_err(|e| anyhow!("{e}"))?; + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; if batches.is_empty() { continue; } diff --git a/src/runtime/streaming/operators/windows/window_function.rs b/src/runtime/streaming/operators/windows/window_function.rs index 66349ea2..cf6a198d 100644 --- a/src/runtime/streaming/operators/windows/window_function.rs +++ b/src/runtime/streaming/operators/windows/window_function.rs @@ -227,10 +227,7 @@ impl Operator for WindowFunctionOperator { for ts in expired_ts { let key = Self::build_state_key(ts); - let batches = store - .get_batches(&key) - .await - .map_err(|e| anyhow!("{e}"))?; + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; if !batches.is_empty() { let (tx, rx) = unbounded_channel(); @@ -252,9 +249,7 @@ impl Operator for WindowFunctionOperator { } } - store - .remove_batches(key) - .map_err(|e| anyhow!("{e}"))?; + store.remove_batches(key).map_err(|e| anyhow!("{e}"))?; self.pending_timestamps.remove(&ts); } From 649f04fa58ef82d17a0271373fcf0ab612b1f211 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Mon, 13 Apr 2026 01:06:46 +0800 Subject: [PATCH 05/26] update --- src/runtime/streaming/api/context.rs | 1 
+ src/runtime/streaming/job/job_manager.rs | 1 + .../grouping/incremental_aggregate.rs | 13 ++++----- src/runtime/streaming/state/operator_state.rs | 28 +++++++++---------- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/runtime/streaming/api/context.rs b/src/runtime/streaming/api/context.rs index f7c47a8f..27babd56 100644 --- a/src/runtime/streaming/api/context.rs +++ b/src/runtime/streaming/api/context.rs @@ -78,6 +78,7 @@ pub struct TaskContext { } impl TaskContext { + #[allow(clippy::too_many_arguments)] pub fn new( job_id: String, pipeline_id: u32, diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index a7c982c4..b0839e4a 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -425,6 +425,7 @@ impl JobManager { .collect()) } + #[allow(clippy::too_many_arguments)] fn build_and_spawn_pipeline( &self, job_id: String, diff --git a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs index 2299bac4..efe0abbb 100644 --- a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs +++ b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs @@ -840,17 +840,16 @@ impl Operator for IncrementalAggregatingFunc { ); } - if let Some(accs) = self.accumulators.get_mut(&key) { - if let Some(IncrementalState::Batch { + if let Some(accs) = self.accumulators.get_mut(&key) + && let Some(IncrementalState::Batch { data, changed_values, .. 
}) = accs.get_mut(acc_idx) - { - let vk = Key(Arc::new(args_row.clone())); - data.insert(vk.clone(), BatchData { count, generation }); - changed_values.insert(vk); - } + { + let vk = Key(Arc::new(args_row.clone())); + data.insert(vk.clone(), BatchData { count, generation }); + changed_values.insert(vk); } } info!(rows = combined.num_rows(), "Restored batch detail state."); diff --git a/src/runtime/streaming/state/operator_state.rs b/src/runtime/streaming/state/operator_state.rs index 69da3192..1eead256 100644 --- a/src/runtime/streaming/state/operator_state.rs +++ b/src/runtime/streaming/state/operator_state.rs @@ -188,10 +188,10 @@ impl OperatorStateStore { } for (table_epoch, table) in self.immutable_tables.lock().iter().rev() { - if let Some(del_ep) = deleted_epoch { - if *table_epoch <= del_ep { - continue; - } + if let Some(del_ep) = deleted_epoch + && *table_epoch <= del_ep + { + continue; } if let Some(batches) = table.get(key) { out.extend(batches.clone()); @@ -208,10 +208,10 @@ impl OperatorStateStore { let mut acc = Vec::new(); for path in paths { let file_epoch = extract_epoch(&path); - if let Some(del_ep) = deleted_epoch { - if file_epoch <= del_ep { - continue; - } + if let Some(del_ep) = deleted_epoch + && file_epoch <= del_ep + { + continue; } // Native Bloom Filter intercepts empty reads here @@ -387,8 +387,8 @@ impl OperatorStateStore { for path in &files_to_merge { let file_epoch = extract_epoch(path); let file = File::open(path).map_err(StateEngineError::IoError)?; - let mut reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; - while let Some(batch) = reader.next() { + let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; + for batch in reader { let b = batch?; if let Some(filtered) = filter_tombstones_from_batch(&b, &tombstone_snapshot, file_epoch)? 
@@ -482,8 +482,8 @@ impl OperatorStateStore { let mut map = HashMap::new(); for path in tomb_paths { let file = File::open(&path).map_err(StateEngineError::IoError)?; - let mut reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; - while let Some(batch) = reader.next() { + let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; + for batch in reader { let batch = batch?; let key_col = batch .column(0) @@ -527,9 +527,9 @@ impl OperatorStateStore { let builder = ParquetRecordBatchReaderBuilder::try_new(file)?; let schema = builder.parquet_schema(); let mask = ProjectionMask::leaves(schema, vec![schema.columns().len() - 1]); - let mut reader = builder.with_projection(mask).build()?; + let reader = builder.with_projection(mask).build()?; - while let Some(batch) = reader.next() { + for batch in reader { let batch = batch?; let key_col = batch .column(0) From 34d05dde0df18f2fc942e7977f87b517daa39a68 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Mon, 13 Apr 2026 21:19:40 +0800 Subject: [PATCH 06/26] update --- .github/workflows/verify-package.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/verify-package.yml b/.github/workflows/verify-package.yml index 085ac61d..10ad364a 100644 --- a/.github/workflows/verify-package.yml +++ b/.github/workflows/verify-package.yml @@ -52,7 +52,9 @@ jobs: libcurl4-openssl-dev \ pkg-config \ libsasl2-dev \ - protobuf-compiler + protobuf-compiler \ + musl-tools + sudo ln -sf /usr/bin/musl-gcc /usr/local/bin/x86_64-linux-musl-gcc - name: Cache Cargo uses: Swatinem/rust-cache@v2 From 029a56a43bfc7b87950c7b2247847cbfe5e25607 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Mon, 13 Apr 2026 21:34:04 +0800 Subject: [PATCH 07/26] update --- Cargo.lock | 1 + Cargo.toml | 3 + src/runtime/streaming/state/operator_state.rs | 291 ++++++++++++++++++ 3 files changed, 295 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 33f0cbf3..b8edca1b 100644 --- a/Cargo.lock +++ 
b/Cargo.lock @@ -2229,6 +2229,7 @@ dependencies = [ "serde_yaml", "sqlparser", "strum", + "tempfile", "thiserror 2.0.18", "tokio", "tokio-stream", diff --git a/Cargo.toml b/Cargo.toml index fb380ff1..531601d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,3 +79,6 @@ governor = "0.8.0" default = ["incremental-cache", "python"] incremental-cache = ["wasmtime/incremental-cache"] python = [] + +[dev-dependencies] +tempfile = "3.27.0" diff --git a/src/runtime/streaming/state/operator_state.rs b/src/runtime/streaming/state/operator_state.rs index 1eead256..86cb3729 100644 --- a/src/runtime/streaming/state/operator_state.rs +++ b/src/runtime/streaming/state/operator_state.rs @@ -710,3 +710,294 @@ fn restore_memtable_from_injected_batches(batches: Vec) -> Result Arc { + Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + false, + )])) + } + + fn make_batch(values: &[i64]) -> RecordBatch { + RecordBatch::try_new( + test_schema(), + vec![Arc::new(Int64Array::from(values.to_vec()))], + ) + .unwrap() + } + + fn setup() -> (TempDir, Arc, IoManager, IoPool) { + let tmp = TempDir::new().unwrap(); + let mem = MemoryController::new(1024 * 1024, 2 * 1024 * 1024); + let metrics: Arc = Arc::new(NoopMetricsCollector); + let (pool, mgr) = IoPool::try_new(1, 1, metrics).unwrap(); + (tmp, mem, mgr, pool) + } + + #[tokio::test] + async fn test_put_and_get() { + let (tmp, mem, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + + let key = b"key-a".to_vec(); + let batch = make_batch(&[10, 20, 30]); + store.put(key.clone(), batch).await.unwrap(); + + let result = store.get_batches(&key).await.unwrap(); + assert_eq!(result.len(), 1); + let col = result[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.values(), &[10, 20, 30]); + } + + #[tokio::test] + async fn test_multiple_puts_same_key() { + let (tmp, mem, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mem, 
mgr).unwrap(); + + let key = b"key-x".to_vec(); + store.put(key.clone(), make_batch(&[1])).await.unwrap(); + store.put(key.clone(), make_batch(&[2])).await.unwrap(); + + let result = store.get_batches(&key).await.unwrap(); + assert_eq!(result.len(), 2); + } + + #[tokio::test] + async fn test_get_nonexistent_key() { + let (tmp, mem, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + + let result = store.get_batches(b"no-such-key").await.unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_remove_batches() { + let (tmp, mem, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + + let key = b"key-del".to_vec(); + store.put(key.clone(), make_batch(&[42])).await.unwrap(); + + store.remove_batches(key.clone()).unwrap(); + + let result = store.get_batches(&key).await.unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_remove_does_not_affect_other_keys() { + let (tmp, mem, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + + let k1 = b"key-1".to_vec(); + let k2 = b"key-2".to_vec(); + store.put(k1.clone(), make_batch(&[1])).await.unwrap(); + store.put(k2.clone(), make_batch(&[2])).await.unwrap(); + + store.remove_batches(k1.clone()).unwrap(); + + assert!(store.get_batches(&k1).await.unwrap().is_empty()); + assert_eq!(store.get_batches(&k2).await.unwrap().len(), 1); + } + + #[tokio::test] + async fn test_snapshot_epoch_advances() { + let (tmp, mem, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + + store.put(b"k".to_vec(), make_batch(&[1])).await.unwrap(); + store.snapshot_epoch(5).unwrap(); + + assert_eq!(store.current_epoch.load(Ordering::Acquire), 6); + } + + #[tokio::test] + async fn test_data_survives_snapshot_via_spill() { + let (tmp, mem, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); 
+ + let key = b"persist".to_vec(); + store.put(key.clone(), make_batch(&[99])).await.unwrap(); + store.snapshot_epoch(1).unwrap(); + + // snapshot_epoch triggers a spill; wait for the background worker to + // flush the data to disk so get_batches can read it from parquet files. + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + + let result = store.get_batches(&key).await.unwrap(); + assert!(!result.is_empty()); + let col = result[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.values(), &[99]); + } + + #[tokio::test] + async fn test_tombstone_hides_immutable_data() { + let (tmp, mem, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + + let key = b"will-die".to_vec(); + store.put(key.clone(), make_batch(&[7])).await.unwrap(); + + // Move to immutable via snapshot + store.snapshot_epoch(1).unwrap(); + + // Tombstone at epoch 2 (> immutable epoch 1) + store.current_epoch.store(2, Ordering::Release); + store.remove_batches(key.clone()).unwrap(); + + let result = store.get_batches(&key).await.unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_memory_controller_tracking() { + let mem = MemoryController::new(1024, 2048); + assert_eq!(mem.usage_bytes(), 0); + + mem.record_inc(100); + assert_eq!(mem.usage_bytes(), 100); + + mem.record_dec(40); + assert_eq!(mem.usage_bytes(), 60); + + assert!(!mem.should_spill()); + mem.record_inc(1000); + assert!(mem.should_spill()); + } + + #[tokio::test] + async fn test_memory_controller_hard_limit() { + let mem = MemoryController::new(512, 1024); + assert!(!mem.exceeds_hard_limit(500)); + assert!(mem.exceeds_hard_limit(1025)); + + mem.record_inc(800); + assert!(mem.exceeds_hard_limit(300)); + assert!(!mem.exceeds_hard_limit(200)); + } + + #[test] + fn test_extract_epoch() { + let path = PathBuf::from("/tmp/data-epoch-42_uuid-abc123.parquet"); + assert_eq!(extract_epoch(&path), 42); + + let path2 = 
PathBuf::from("/tmp/tombstone-epoch-100_uuid-def456.parquet"); + assert_eq!(extract_epoch(&path2), 100); + + let path3 = PathBuf::from("/tmp/random-file.parquet"); + assert_eq!(extract_epoch(&path3), 0); + } + + #[test] + fn test_inject_and_strip_partition_key() { + let batch = make_batch(&[1, 2, 3]); + let key = b"pk-test"; + + let injected = inject_partition_key(&batch, key).unwrap(); + assert_eq!(injected.num_columns(), 2); + assert!(injected.schema().index_of(PARTITION_KEY_COL).is_ok()); + + let stripped = filter_and_strip_partition_key(&injected, key) + .unwrap() + .unwrap(); + assert_eq!(stripped.num_columns(), 1); + let col = stripped + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.values(), &[1, 2, 3]); + } + + #[test] + fn test_filter_partition_key_mismatch() { + let batch = make_batch(&[1, 2]); + let injected = inject_partition_key(&batch, b"pk-a").unwrap(); + + let result = filter_and_strip_partition_key(&injected, b"pk-b").unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_restore_memtable_roundtrip() { + let batch1 = inject_partition_key(&make_batch(&[10]), b"k1").unwrap(); + let batch2 = inject_partition_key(&make_batch(&[20]), b"k2").unwrap(); + let batch3 = inject_partition_key(&make_batch(&[30]), b"k1").unwrap(); + + let restored = + restore_memtable_from_injected_batches(vec![batch1, batch2, batch3]).unwrap(); + + assert_eq!(restored.len(), 2); + assert_eq!(restored[b"k1".as_ref()].len(), 2); + assert_eq!(restored[b"k2".as_ref()].len(), 1); + } + + #[test] + fn test_write_and_read_parquet() { + let tmp = TempDir::new().unwrap(); + let path = tmp.path().join("test.parquet"); + + let batch = make_batch(&[100, 200, 300]); + write_parquet_with_bloom_atomic(&path, std::slice::from_ref(&batch), 1).unwrap(); + + let file = File::open(&path).unwrap(); + let reader = ParquetRecordBatchReaderBuilder::try_new(file) + .unwrap() + .build() + .unwrap(); + + let read_batches: Vec = reader.map(|r| r.unwrap()).collect(); + 
assert_eq!(read_batches.len(), 1); + let col = read_batches[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.values(), &[100, 200, 300]); + } + + #[test] + fn test_filter_tombstones_from_batch() { + let batch = make_batch(&[1, 2, 3]); + let key = b"victim"; + let injected = inject_partition_key(&batch, key).unwrap(); + + let mut tombstones: TombstoneMap = HashMap::new(); + tombstones.insert(key.to_vec(), 10); + + // file_epoch <= tombstone epoch => fully filtered + let result = filter_tombstones_from_batch(&injected, &tombstones, 5).unwrap(); + assert!(result.is_none()); + + // file_epoch > tombstone epoch => data survives + let result = filter_tombstones_from_batch(&injected, &tombstones, 15).unwrap(); + assert!(result.is_some()); + } + + #[test] + fn test_write_empty_batches_is_noop() { + let tmp = TempDir::new().unwrap(); + let path = tmp.path().join("empty.parquet"); + + write_parquet_with_bloom_atomic(&path, &[], 0).unwrap(); + assert!(!path.exists()); + } +} From 429bb510ee98129e524efdbc7779dd18bb46f85f Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Thu, 16 Apr 2026 23:44:42 +0800 Subject: [PATCH 08/26] update --- .github/workflows/verify-package.yml | 4 +--- Makefile | 5 ++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/.github/workflows/verify-package.yml b/.github/workflows/verify-package.yml index 10ad364a..085ac61d 100644 --- a/.github/workflows/verify-package.yml +++ b/.github/workflows/verify-package.yml @@ -52,9 +52,7 @@ jobs: libcurl4-openssl-dev \ pkg-config \ libsasl2-dev \ - protobuf-compiler \ - musl-tools - sudo ln -sf /usr/bin/musl-gcc /usr/local/bin/x86_64-linux-musl-gcc + protobuf-compiler - name: Cache Cargo uses: Swatinem/rust-cache@v2 diff --git a/Makefile b/Makefile index e914b376..5d03c30a 100644 --- a/Makefile +++ b/Makefile @@ -32,9 +32,8 @@ OS_NAME := $(shell uname -s) # 2. 
Configure RUSTFLAGS and target triple per platform DIST_ROOT := dist ifeq ($(OS_NAME), Linux) - # Linux: static-link musl for a truly self-contained, zero-dependency binary - TRIPLE := $(ARCH)-unknown-linux-musl - STATIC_FLAGS := -C target-feature=+crt-static + TRIPLE := $(ARCH)-unknown-linux-gnu + STATIC_FLAGS := else ifeq ($(OS_NAME), Darwin) # macOS: strip symbols but keep dynamic linking (Apple system restriction) TRIPLE := $(ARCH)-apple-darwin From de248a766700c490ff29d20a7b0ed6486ad08523 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Fri, 17 Apr 2026 00:40:55 +0800 Subject: [PATCH 09/26] update --- .../operators/grouping/incremental_aggregate.rs | 6 +++--- src/sql/logical_planner/streaming_planner.rs | 2 +- src/sql/types/data_type.rs | 11 ++++------- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs index efe0abbb..346199f6 100644 --- a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs +++ b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs @@ -471,7 +471,7 @@ impl IncrementalAggregatingFunc { state }); - for (idx, v) in agg.state_cols.iter().zip(state.into_iter()) { + for (idx, v) in agg.state_cols.iter().zip(state) { states[*idx].push(v); } } @@ -555,7 +555,7 @@ impl IncrementalAggregatingFunc { .iter() .map(|kb| parser.parse(kb).to_owned()) .collect(); - let mut cols = self.key_converter.convert_rows(rows.into_iter())?; + let mut cols = self.key_converter.convert_rows(rows)?; cols.push(Arc::new(accumulator_builder.finish())); cols.push(Arc::new(args_row_builder.finish())); cols.push(Arc::new(count_builder.finish())); @@ -624,7 +624,7 @@ impl IncrementalAggregatingFunc { mem::take(&mut self.updated_keys).into_iter().unzip(); let mut deleted_keys = vec![]; - for (k, retract) in updated_keys.iter().zip(updated_values.into_iter()) { + for (k, retract) in 
updated_keys.iter().zip(updated_values) { let append = self.evaluate(&k.0)?; if let Some(v) = retract { diff --git a/src/sql/logical_planner/streaming_planner.rs b/src/sql/logical_planner/streaming_planner.rs index e501695d..4619fb3f 100644 --- a/src/sql/logical_planner/streaming_planner.rs +++ b/src/sql/logical_planner/streaming_planner.rs @@ -341,7 +341,7 @@ impl PlanToGraphVisitor<'_> { let node_index = self.graph.add_node(execution_unit); self.add_index_to_traversal(node_index); - for (source, edge) in input_nodes.into_iter().zip(routing_edges.into_iter()) { + for (source, edge) in input_nodes.into_iter().zip(routing_edges) { self.graph.add_edge(source, node_index, edge); } diff --git a/src/sql/types/data_type.rs b/src/sql/types/data_type.rs index 070324d5..387a4190 100644 --- a/src/sql/types/data_type.rs +++ b/src/sql/types/data_type.rs @@ -98,14 +98,11 @@ fn convert_simple_data_type( } }, SQLDataType::Date => Ok(DataType::Date32), - SQLDataType::Time(None, tz_info) => { + SQLDataType::Time(None, tz_info) if matches!(tz_info, TimezoneInfo::None) - || matches!(tz_info, TimezoneInfo::WithoutTimeZone) - { - Ok(DataType::Time64(TimeUnit::Nanosecond)) - } else { - return plan_err!("Unsupported SQL type {sql_type:?}"); - } + || matches!(tz_info, TimezoneInfo::WithoutTimeZone) => + { + Ok(DataType::Time64(TimeUnit::Nanosecond)) } SQLDataType::Numeric(exact_number_info) | SQLDataType::Decimal(exact_number_info) => { let (precision, scale) = match *exact_number_info { From 55bdff855965731ca2d5a9565dc53ae75bdbb1e0 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Fri, 17 Apr 2026 23:06:16 +0800 Subject: [PATCH 10/26] update --- src/runtime/streaming/state/operator_state.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/runtime/streaming/state/operator_state.rs b/src/runtime/streaming/state/operator_state.rs index 86cb3729..84838ca2 100644 --- a/src/runtime/streaming/state/operator_state.rs +++ 
b/src/runtime/streaming/state/operator_state.rs @@ -170,6 +170,12 @@ impl OperatorStateStore { Ok(()) } + pub async fn await_spill_complete(&self) { + while self.is_spilling.load(Ordering::SeqCst) { + self.spill_notify.notified().await; + } + } + fn downgrade_active_table(&self, epoch: u64) { let mut active_guard = self.active_table.write(); if active_guard.is_empty() { @@ -833,10 +839,7 @@ mod tests { let key = b"persist".to_vec(); store.put(key.clone(), make_batch(&[99])).await.unwrap(); store.snapshot_epoch(1).unwrap(); - - // snapshot_epoch triggers a spill; wait for the background worker to - // flush the data to disk so get_batches can read it from parquet files. - tokio::time::sleep(std::time::Duration::from_millis(200)).await; + store.await_spill_complete().await; let result = store.get_batches(&key).await.unwrap(); assert!(!result.is_empty()); From 0602512498d1bc72bf95feb09491cdc23841d0b5 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 19 Apr 2026 00:11:08 +0800 Subject: [PATCH 11/26] update --- conf/config.yaml | 13 + src/config/global_config.rs | 6 +- src/config/mod.rs | 1 + src/config/system.rs | 233 ++++++++++++++++++ src/coordinator/execution/executor.rs | 2 +- src/coordinator/runtime_context.rs | 2 +- src/runtime/memory/block.rs | 75 ++++++ src/runtime/memory/error.rs | 35 +++ src/runtime/memory/global.rs | 56 +++++ src/runtime/{streaming => }/memory/mod.rs | 20 ++ src/runtime/{streaming => }/memory/pool.rs | 55 +++-- src/runtime/{streaming => }/memory/ticket.rs | 18 +- src/runtime/mod.rs | 4 +- src/runtime/streaming/api/context.rs | 33 +-- src/runtime/streaming/job/job_manager.rs | 29 +-- src/runtime/streaming/mod.rs | 1 - .../grouping/incremental_aggregate.rs | 1 - .../operators/joins/join_instance.rs | 1 - .../operators/joins/join_with_expiration.rs | 1 - .../windows/session_aggregating_window.rs | 1 - .../windows/sliding_aggregating_window.rs | 1 - .../windows/tumbling_aggregating_window.rs | 1 - .../operators/windows/window_function.rs | 1 
- src/runtime/streaming/protocol/event.rs | 2 +- src/runtime/streaming/state/metrics.rs | 4 +- src/runtime/streaming/state/mod.rs | 2 +- src/runtime/streaming/state/operator_state.rs | 173 +++++++------ .../buffer_and_event/buffer_or_event.rs | 0 .../{ => wasm}/buffer_and_event/mod.rs | 0 .../buffer_and_event/stream_element/mod.rs | 0 .../stream_element/stream_element.rs | 0 src/runtime/wasm/input/input_protocol.rs | 2 +- src/runtime/wasm/input/input_provider.rs | 2 +- src/runtime/wasm/input/input_runner.rs | 8 +- src/runtime/wasm/input/interface.rs | 4 +- .../input/protocol/kafka/kafka_protocol.rs | 2 +- src/runtime/wasm/mod.rs | 3 + src/runtime/wasm/output/interface.rs | 4 +- src/runtime/wasm/output/output_protocol.rs | 2 +- src/runtime/wasm/output/output_provider.rs | 2 +- src/runtime/wasm/output/output_runner.rs | 8 +- .../output/protocol/kafka/kafka_protocol.rs | 2 +- src/runtime/wasm/processor/wasm/wasm_host.rs | 6 +- .../wasm/processor/wasm/wasm_processor.rs | 4 +- .../processor/wasm/wasm_processor_trait.rs | 2 +- src/runtime/wasm/processor/wasm/wasm_task.rs | 12 +- src/runtime/{ => wasm}/task/builder/mod.rs | 0 .../{ => wasm}/task/builder/processor/mod.rs | 4 +- .../{ => wasm}/task/builder/python/mod.rs | 6 +- .../{ => wasm}/task/builder/sink/mod.rs | 2 +- .../{ => wasm}/task/builder/source/mod.rs | 2 +- .../{ => wasm}/task/builder/task_builder.rs | 12 +- .../{ => wasm}/task/control_mailbox.rs | 0 src/runtime/{ => wasm}/task/lifecycle.rs | 4 +- src/runtime/{ => wasm}/task/mod.rs | 0 .../{ => wasm}/task/processor_config.rs | 2 +- src/runtime/{ => wasm}/task/yaml_keys.rs | 0 .../{ => wasm}/taskexecutor/init_context.rs | 2 +- src/runtime/{ => wasm}/taskexecutor/mod.rs | 0 .../{ => wasm}/taskexecutor/task_manager.rs | 4 +- src/server/initializer.rs | 79 +++++- 61 files changed, 721 insertions(+), 230 deletions(-) create mode 100644 src/config/system.rs create mode 100644 src/runtime/memory/block.rs create mode 100644 src/runtime/memory/error.rs create mode 
100644 src/runtime/memory/global.rs rename src/runtime/{streaming => }/memory/mod.rs (54%) rename src/runtime/{streaming => }/memory/pool.rs (56%) rename src/runtime/{streaming => }/memory/ticket.rs (71%) rename src/runtime/{ => wasm}/buffer_and_event/buffer_or_event.rs (100%) rename src/runtime/{ => wasm}/buffer_and_event/mod.rs (100%) rename src/runtime/{ => wasm}/buffer_and_event/stream_element/mod.rs (100%) rename src/runtime/{ => wasm}/buffer_and_event/stream_element/stream_element.rs (100%) rename src/runtime/{ => wasm}/task/builder/mod.rs (100%) rename src/runtime/{ => wasm}/task/builder/processor/mod.rs (97%) rename src/runtime/{ => wasm}/task/builder/python/mod.rs (95%) rename src/runtime/{ => wasm}/task/builder/sink/mod.rs (97%) rename src/runtime/{ => wasm}/task/builder/source/mod.rs (97%) rename src/runtime/{ => wasm}/task/builder/task_builder.rs (94%) rename src/runtime/{ => wasm}/task/control_mailbox.rs (100%) rename src/runtime/{ => wasm}/task/lifecycle.rs (97%) rename src/runtime/{ => wasm}/task/mod.rs (100%) rename src/runtime/{ => wasm}/task/processor_config.rs (99%) rename src/runtime/{ => wasm}/task/yaml_keys.rs (100%) rename src/runtime/{ => wasm}/taskexecutor/init_context.rs (97%) rename src/runtime/{ => wasm}/taskexecutor/mod.rs (100%) rename src/runtime/{ => wasm}/taskexecutor/task_manager.rs (98%) diff --git a/conf/config.yaml b/conf/config.yaml index 9d0f625e..cfb71d02 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -49,6 +49,19 @@ wasm: # When cache exceeds this size, least recently used items will be evicted max_cache_size: 104857600 +# Streaming Runtime Configuration +streaming: + # Global memory pool size for the streaming runtime (network buffering, backpressure). + # When not set, auto-detected as 70% of physical memory. + # Fallback: 256 MiB if detection fails. + # max_memory_bytes: 268435456 + + # Memory budget per stateful operator (aggregation, join, window). 
+ # Each operator gets its own independent memory controller with this limit. + # When exceeded, the operator spills state to disk automatically. + # Default: 67108864 (64 MiB) + per_operator_state_memory_bytes: 67108864 + # State Storage Configuration # Used to store runtime state data for tasks state_storage: diff --git a/src/config/global_config.rs b/src/config/global_config.rs index c76bf4b0..90332c25 100644 --- a/src/config/global_config.rs +++ b/src/config/global_config.rs @@ -21,9 +21,9 @@ use crate::config::wasm_config::WasmConfig; #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct StreamingConfig { - /// Maximum heap memory (in bytes) available to the streaming runtime's memory pool. - /// Defaults to 256 MiB when absent. - pub max_memory_bytes: Option, + pub max_memory_bytes: Option, + /// Total bytes for the global operator-state [`MemoryPool`](crate::runtime::memory::MemoryPool) (all stores share this quota). + pub per_operator_state_memory_bytes: Option, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] diff --git a/src/config/mod.rs b/src/config/mod.rs index f08051af..489063e1 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -17,6 +17,7 @@ pub mod paths; pub mod python_config; pub mod service_config; pub mod storage; +pub mod system; pub mod wasm_config; pub use global_config::GlobalConfig; diff --git a/src/config/system.rs b/src/config/system.rs new file mode 100644 index 00000000..d7a37ddf --- /dev/null +++ b/src/config/system.rs @@ -0,0 +1,233 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +use std::io; + +pub struct SystemMemoryInfo { + pub total_physical: u64, + pub available_physical: u64, + pub total_virtual: u64, + pub available_virtual: u64, +} + +pub fn system_memory_info() -> io::Result { + sys::system_memory_info() +} + +#[cfg(target_os = "linux")] +mod sys { + use super::SystemMemoryInfo; + use std::io; + + pub fn system_memory_info() -> io::Result { + let content = std::fs::read_to_string("/proc/meminfo")?; + + let mut total_physical: Option = None; + let mut available_physical: Option = None; + let mut swap_total: u64 = 0; + let mut swap_free: u64 = 0; + + for line in content.lines() { + if let Some(v) = parse_meminfo_kb(line, "MemTotal:") { + total_physical = Some(v); + } else if let Some(v) = parse_meminfo_kb(line, "MemAvailable:") { + available_physical = Some(v); + } else if let Some(v) = parse_meminfo_kb(line, "SwapTotal:") { + swap_total = v; + } else if let Some(v) = parse_meminfo_kb(line, "SwapFree:") { + swap_free = v; + } + } + + let total_phys = total_physical.ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotFound, + "MemTotal not found in /proc/meminfo", + ) + })?; + let avail_phys = available_physical.unwrap_or(0); + + Ok(SystemMemoryInfo { + total_physical: total_phys, + available_physical: avail_phys, + total_virtual: total_phys + swap_total, + available_virtual: avail_phys + swap_free, + }) + } + + fn parse_meminfo_kb(line: &str, prefix: &str) -> Option { + let rest = line.strip_prefix(prefix)?; + let kb: u64 = rest.trim().trim_end_matches("kB").trim().parse().ok()?; + Some(kb * 1024) + } +} + +#[cfg(target_os = "macos")] +mod sys { + use super::SystemMemoryInfo; + use std::io; + + pub fn system_memory_info() -> io::Result { + let total_physical = sysctl_u64("hw.memsize")?; + + let page_size = sysctl_u64("hw.pagesize").unwrap_or(4096); + let vm_stats = read_vm_stat()?; + + let free_pages = vm_stats.free + 
vm_stats.inactive + vm_stats.purgeable; + let available_physical = free_pages * page_size; + + let swap = read_swap_usage(); + let swap_total = swap.0; + let swap_free = swap_total.saturating_sub(swap.1); + + Ok(SystemMemoryInfo { + total_physical, + available_physical, + total_virtual: total_physical + swap_total, + available_virtual: available_physical + swap_free, + }) + } + + fn sysctl_u64(name: &str) -> io::Result { + let output = std::process::Command::new("sysctl") + .arg("-n") + .arg(name) + .output()?; + if !output.status.success() { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("sysctl {name} failed"), + )); + } + String::from_utf8_lossy(&output.stdout) + .trim() + .parse() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + } + + struct VmPages { + free: u64, + inactive: u64, + purgeable: u64, + } + + fn read_vm_stat() -> io::Result { + let output = std::process::Command::new("vm_stat").output()?; + let text = String::from_utf8_lossy(&output.stdout); + + let mut free = 0u64; + let mut inactive = 0u64; + let mut purgeable = 0u64; + + for line in text.lines() { + if let Some(v) = parse_vm_stat_line(line, "Pages free") { + free = v; + } else if let Some(v) = parse_vm_stat_line(line, "Pages inactive") { + inactive = v; + } else if let Some(v) = parse_vm_stat_line(line, "Pages purgeable") { + purgeable = v; + } + } + + Ok(VmPages { + free, + inactive, + purgeable, + }) + } + + fn parse_vm_stat_line(line: &str, key: &str) -> Option { + if !line.contains(key) { + return None; + } + let val_str = line.rsplit(':').next()?.trim().trim_end_matches('.'); + val_str.parse().ok() + } + + fn read_swap_usage() -> (u64, u64) { + let output = match std::process::Command::new("sysctl") + .arg("-n") + .arg("vm.swapusage") + .output() + { + Ok(o) => o, + Err(_) => return (0, 0), + }; + let text = String::from_utf8_lossy(&output.stdout); + let mut total = 0u64; + let mut used = 0u64; + for part in text.split_whitespace() { + if let Some(mb_str) = 
part.strip_suffix("M") { + if let Ok(mb) = mb_str.parse::() { + if total == 0 { + total = (mb * 1024.0 * 1024.0) as u64; + } else if used == 0 { + used = (mb * 1024.0 * 1024.0) as u64; + } + } + } + } + (total, used) + } +} + +#[cfg(target_os = "windows")] +mod sys { + use super::SystemMemoryInfo; + use std::io; + + #[repr(C)] + struct MemoryStatusEx { + dw_length: u32, + dw_memory_load: u32, + ull_total_phys: u64, + ull_avail_phys: u64, + ull_total_page_file: u64, + ull_avail_page_file: u64, + ull_total_virtual: u64, + ull_avail_virtual: u64, + ull_avail_extended_virtual: u64, + } + + extern "system" { + fn GlobalMemoryStatusEx(lpBuffer: *mut MemoryStatusEx) -> i32; + } + + pub fn system_memory_info() -> io::Result { + unsafe { + let mut status = std::mem::zeroed::(); + status.dw_length = std::mem::size_of::() as u32; + if GlobalMemoryStatusEx(&mut status) == 0 { + return Err(io::Error::last_os_error()); + } + Ok(SystemMemoryInfo { + total_physical: status.ull_total_phys, + available_physical: status.ull_avail_phys, + total_virtual: status.ull_total_virtual, + available_virtual: status.ull_avail_virtual, + }) + } + } +} + +#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))] +mod sys { + use super::SystemMemoryInfo; + use std::io; + + pub fn system_memory_info() -> io::Result { + Err(io::Error::new( + io::ErrorKind::Unsupported, + "memory detection not supported on this platform", + )) + } +} diff --git a/src/coordinator/execution/executor.rs b/src/coordinator/execution/executor.rs index 7dc3c0ff..0e7a79b1 100644 --- a/src/coordinator/execution/executor.rs +++ b/src/coordinator/execution/executor.rs @@ -30,7 +30,7 @@ use crate::coordinator::plan::{ use crate::coordinator::statement::{ConfigSource, FunctionSource}; use crate::runtime::streaming::job::JobManager; use crate::runtime::streaming::protocol::control::StopMode; -use crate::runtime::taskexecutor::TaskManager; +use crate::runtime::wasm::taskexecutor::TaskManager; use 
crate::sql::schema::show_create_catalog_table; use crate::sql::schema::table::Table as CatalogTable; use crate::storage::stream_catalog::CatalogManager; diff --git a/src/coordinator/runtime_context.rs b/src/coordinator/runtime_context.rs index 5d671b98..21b9d876 100644 --- a/src/coordinator/runtime_context.rs +++ b/src/coordinator/runtime_context.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use anyhow::Result; use crate::runtime::streaming::job::JobManager; -use crate::runtime::taskexecutor::TaskManager; +use crate::runtime::wasm::taskexecutor::TaskManager; use crate::sql::schema::StreamSchemaProvider; use crate::storage::stream_catalog::CatalogManager; diff --git a/src/runtime/memory/block.rs b/src/runtime/memory/block.rs new file mode 100644 index 00000000..18f30de5 --- /dev/null +++ b/src/runtime/memory/block.rs @@ -0,0 +1,75 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +use super::pool::MemoryPool; +use super::ticket::MemoryTicket; + +#[derive(Debug)] +pub struct MemoryBlock { + capacity: u64, + available_bytes: AtomicU64, + pool: Arc, +} + +impl MemoryBlock { + pub(crate) fn new(capacity: u64, pool: Arc) -> Arc { + Arc::new(Self { + capacity, + available_bytes: AtomicU64::new(capacity), + pool, + }) + } + + pub fn try_allocate(self: &Arc, bytes: u64) -> Option { + if bytes == 0 { + return Some(MemoryTicket::new(0, self.clone())); + } + + let mut current_available = self.available_bytes.load(Ordering::Acquire); + loop { + if current_available < bytes { + return None; + } + + match self.available_bytes.compare_exchange_weak( + current_available, + current_available - bytes, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => return Some(MemoryTicket::new(bytes, self.clone())), + Err(actual) => current_available = actual, + } + } + } + + #[inline] + pub fn available_bytes(&self) -> u64 { + self.available_bytes.load(Ordering::Relaxed) + } + + pub(crate) fn release_ticket(&self, bytes: u64) { + if bytes > 0 { + self.available_bytes.fetch_add(bytes, Ordering::Release); + } + } +} + +impl Drop for MemoryBlock { + fn drop(&mut self) { + self.pool.release_block(self.capacity); + } +} diff --git a/src/runtime/memory/error.rs b/src/runtime/memory/error.rs new file mode 100644 index 00000000..a5d5152f --- /dev/null +++ b/src/runtime/memory/error.rs @@ -0,0 +1,35 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MemoryError { + AlreadyInitialized, + Uninitialized, +} + +impl fmt::Display for MemoryError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MemoryError::AlreadyInitialized => { + write!(f, "Global memory pool is already initialized") + } + MemoryError::Uninitialized => { + write!(f, "Global memory pool is not initialized") + } + } + } +} + +impl std::error::Error for MemoryError {} diff --git a/src/runtime/memory/global.rs b/src/runtime/memory/global.rs new file mode 100644 index 00000000..97a38d1f --- /dev/null +++ b/src/runtime/memory/global.rs @@ -0,0 +1,56 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::sync::{Arc, OnceLock}; + +use super::error::MemoryError; +use super::pool::MemoryPool; + +static GLOBAL_POOL: OnceLock> = OnceLock::new(); +static GLOBAL_STATE_POOL: OnceLock> = OnceLock::new(); + +pub fn init_global_memory_pool(max_bytes: u64) -> Result<(), MemoryError> { + GLOBAL_POOL + .set(MemoryPool::new(max_bytes)) + .map_err(|_| MemoryError::AlreadyInitialized) +} + +pub fn try_global_memory_pool() -> Result, MemoryError> { + GLOBAL_POOL.get().cloned().ok_or(MemoryError::Uninitialized) +} + +#[inline] +pub fn global_memory_pool() -> Arc { + try_global_memory_pool().expect("Global streaming pool must be initialized before use") +} + +pub fn init_global_state_memory_pool(max_bytes: u64) -> Result<(), MemoryError> { + GLOBAL_STATE_POOL + .set(MemoryPool::new(max_bytes)) + .map_err(|_| MemoryError::AlreadyInitialized) +} + +pub fn try_global_state_memory_pool() -> Result, MemoryError> { + GLOBAL_STATE_POOL.get().cloned().ok_or(MemoryError::Uninitialized) +} + +#[inline] +pub fn global_state_memory_pool() -> Arc { + try_global_state_memory_pool().expect("Global state pool must be initialized before use") +} + +pub fn get_memory_metrics() -> (Option<(u64, u64)>, Option<(u64, u64)>) { + let stream_metrics = GLOBAL_POOL.get().map(|p| p.usage_metrics()); + let state_metrics = GLOBAL_STATE_POOL.get().map(|p| p.usage_metrics()); + (stream_metrics, state_metrics) +} diff --git a/src/runtime/streaming/memory/mod.rs b/src/runtime/memory/mod.rs similarity index 54% rename from src/runtime/streaming/memory/mod.rs rename to src/runtime/memory/mod.rs index 45fc3194..5a50028d 100644 --- a/src/runtime/streaming/memory/mod.rs +++ b/src/runtime/memory/mod.rs @@ -1,5 +1,6 @@ // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
+// // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 @@ -10,8 +11,27 @@ // See the License for the specific language governing permissions and // limitations under the License. +use arrow_array::RecordBatch; + +mod block; +mod error; +pub mod global; pub mod pool; pub mod ticket; +#[allow(unused_imports)] +pub use block::MemoryBlock; +#[allow(unused_imports)] +pub use error::MemoryError; +#[allow(unused_imports)] +pub use global::{ + get_memory_metrics, global_memory_pool, global_state_memory_pool, init_global_memory_pool, + init_global_state_memory_pool, try_global_memory_pool, try_global_state_memory_pool, +}; pub use pool::MemoryPool; pub use ticket::MemoryTicket; + +#[inline] +pub fn get_array_memory_size(batch: &RecordBatch) -> u64 { + RecordBatch::get_array_memory_size(batch) as u64 +} diff --git a/src/runtime/streaming/memory/pool.rs b/src/runtime/memory/pool.rs similarity index 56% rename from src/runtime/streaming/memory/pool.rs rename to src/runtime/memory/pool.rs index b6a06ad2..3261c9b0 100644 --- a/src/runtime/streaming/memory/pool.rs +++ b/src/runtime/memory/pool.rs @@ -1,5 +1,6 @@ // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
+// // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 @@ -12,47 +13,48 @@ use parking_lot::Mutex; use std::sync::Arc; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicU64, Ordering}; use tokio::sync::Notify; use tracing::{debug, warn}; -use super::ticket::MemoryTicket; +use super::block::MemoryBlock; #[derive(Debug)] pub struct MemoryPool { - max_bytes: usize, - used_bytes: AtomicUsize, - available_bytes: Mutex, + max_bytes: u64, + used_bytes: AtomicU64, + available_bytes: Mutex, notify: Notify, } impl MemoryPool { - pub fn new(max_bytes: usize) -> Arc { + pub fn new(max_bytes: u64) -> Arc { Arc::new(Self { max_bytes, - used_bytes: AtomicUsize::new(0), + used_bytes: AtomicU64::new(0), available_bytes: Mutex::new(max_bytes), notify: Notify::new(), }) } - pub fn usage_metrics(&self) -> (usize, usize) { + pub fn usage_metrics(&self) -> (u64, u64) { (self.used_bytes.load(Ordering::Relaxed), self.max_bytes) } - pub async fn request_memory(self: &Arc, bytes: usize) -> MemoryTicket { + pub async fn request_block(self: &Arc, bytes: u64) -> Arc { if bytes == 0 { - return MemoryTicket::new(0, self.clone()); + return MemoryBlock::new(0, self.clone()); } if bytes > self.max_bytes { warn!( - "Requested memory ({} B) exceeds total pool size ({} B)! \ - Permitting to avoid pipeline deadlock, but OOM risk is critical.", - bytes, self.max_bytes + request_bytes = bytes, + max_bytes = self.max_bytes, + "Requested memory block exceeds total pool size! \ + Permitting to avoid pipeline deadlock, but critical OOM risk exists." 
); self.used_bytes.fetch_add(bytes, Ordering::Relaxed); - return MemoryTicket::new(bytes, self.clone()); + return MemoryBlock::new(bytes, self.clone()); } loop { @@ -61,19 +63,32 @@ impl MemoryPool { if *available >= bytes { *available -= bytes; self.used_bytes.fetch_add(bytes, Ordering::Relaxed); - return MemoryTicket::new(bytes, self.clone()); + return MemoryBlock::new(bytes, self.clone()); } } - debug!( - "Backpressure engaged: waiting for {} bytes to be freed...", - bytes - ); + debug!(bytes = bytes, "Global backpressure engaged: waiting for memory..."); self.notify.notified().await; } } - pub(crate) fn release(&self, bytes: usize) { + pub fn force_reserve(&self, bytes: u64) { + if bytes == 0 { + return; + } + let mut available = self.available_bytes.lock(); + *available = available.saturating_sub(bytes); + self.used_bytes.fetch_add(bytes, Ordering::Relaxed); + } + + pub fn force_release(&self, bytes: u64) { + if bytes == 0 { + return; + } + self.release_block(bytes); + } + + pub(crate) fn release_block(&self, bytes: u64) { if bytes == 0 { return; } diff --git a/src/runtime/streaming/memory/ticket.rs b/src/runtime/memory/ticket.rs similarity index 71% rename from src/runtime/streaming/memory/ticket.rs rename to src/runtime/memory/ticket.rs index cb105be0..24362e2f 100644 --- a/src/runtime/streaming/memory/ticket.rs +++ b/src/runtime/memory/ticket.rs @@ -1,5 +1,6 @@ // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
+// // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 @@ -12,22 +13,27 @@ use std::sync::Arc; -use super::pool::MemoryPool; +use super::block::MemoryBlock; #[derive(Debug)] pub struct MemoryTicket { - bytes: usize, - pool: Arc, + bytes: u64, + block: Arc, } impl MemoryTicket { - pub(crate) fn new(bytes: usize, pool: Arc) -> Self { - Self { bytes, pool } + pub(crate) fn new(bytes: u64, block: Arc) -> Self { + Self { bytes, block } + } + + #[inline] + pub fn bytes(&self) -> u64 { + self.bytes } } impl Drop for MemoryTicket { fn drop(&mut self) { - self.pool.release(self.bytes); + self.block.release_ticket(self.bytes); } } diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 1ba5e2a3..8c72b507 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -12,11 +12,9 @@ // Runtime module -pub mod buffer_and_event; pub mod common; +pub mod memory; pub mod streaming; -pub mod task; -pub mod taskexecutor; pub mod util; pub mod wasm; diff --git a/src/runtime/streaming/api/context.rs b/src/runtime/streaming/api/context.rs index 27babd56..4e5af1a8 100644 --- a/src/runtime/streaming/api/context.rs +++ b/src/runtime/streaming/api/context.rs @@ -14,13 +14,13 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, SystemTime}; -use anyhow::{Context, Result}; +use anyhow::{anyhow, Context, Result}; use arrow_array::RecordBatch; -use crate::runtime::streaming::memory::MemoryPool; +use crate::runtime::memory::{get_array_memory_size, MemoryPool}; use crate::runtime::streaming::network::endpoint::PhysicalSender; use crate::runtime::streaming::protocol::event::{StreamEvent, TrackedEvent}; -use crate::runtime::streaming::state::{IoManager, MemoryController}; +use crate::runtime::streaming::state::IoManager; #[derive(Debug, Clone)] pub struct TaskContextConfig { @@ -55,7 +55,7 @@ pub struct TaskContext { /// Downstream physical senders (outbound edges). 
downstream_senders: Vec, - /// Global memory pool for back-pressure and accounting. + /// Job-wide shared pool; memory is accounted only when [`Self::collect`] / [`Self::collect_keyed`] run. memory_pool: Arc, /// Latest aligned event-time watermark for this subtask. @@ -64,13 +64,7 @@ pub struct TaskContext { /// Subtask-level tunables. config: TaskContextConfig, - /// Root directory for operator state persistence (LSM-Tree data/tombstone files). pub state_dir: PathBuf, - - /// Shared memory controller for state engine back-pressure. - pub memory_controller: Arc, - - /// I/O thread pool handle for background spill/compaction. pub io_manager: IoManager, /// Last globally-committed safe epoch for crash recovery. @@ -86,7 +80,6 @@ impl TaskContext { parallelism: u32, downstream_senders: Vec, memory_pool: Arc, - memory_controller: Arc, io_manager: IoManager, state_dir: PathBuf, safe_epoch: u64, @@ -107,7 +100,6 @@ impl TaskContext { current_watermark: None, config: TaskContextConfig::default(), state_dir, - memory_controller, io_manager, safe_epoch, } @@ -150,13 +142,19 @@ impl TaskContext { // ------------------------------------------------------------------------- /// Fan-out a data batch to all downstreams (forward / broadcast). + /// + /// Back-pressure and memory accounting happen here via [`MemoryPool::request_block`], not + /// when building the pipeline. 
pub async fn collect(&self, batch: RecordBatch) -> Result<()> { if self.downstream_senders.is_empty() { return Ok(()); } - let bytes_required = batch.get_array_memory_size(); - let ticket = self.memory_pool.request_memory(bytes_required).await; + let bytes_required = get_array_memory_size(&batch); + let block = self.memory_pool.request_block(bytes_required).await; + let ticket = block + .try_allocate(bytes_required) + .ok_or_else(|| anyhow!("memory block allocation failed"))?; let tracked_event = TrackedEvent::new(StreamEvent::Data(batch), Some(ticket)); self.broadcast_event(tracked_event).await @@ -169,8 +167,11 @@ impl TaskContext { return Ok(()); } - let bytes_required = batch.get_array_memory_size(); - let ticket = self.memory_pool.request_memory(bytes_required).await; + let bytes_required = get_array_memory_size(&batch); + let block = self.memory_pool.request_block(bytes_required).await; + let ticket = block + .try_allocate(bytes_required) + .ok_or_else(|| anyhow!("memory block allocation failed"))?; let event = TrackedEvent::new(StreamEvent::Data(batch), Some(ticket)); let target_idx = (key_hash as usize) % num_downstreams; diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index b0839e4a..3b8aa0f8 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -32,11 +32,11 @@ use crate::runtime::streaming::job::edge_manager::EdgeManager; use crate::runtime::streaming::job::models::{ PhysicalExecutionGraph, PhysicalPipeline, PipelineStatus, StreamingJobRollupStatus, }; -use crate::runtime::streaming::memory::MemoryPool; +use crate::runtime::memory::global_memory_pool; use crate::runtime::streaming::network::endpoint::{BoxedEventStream, PhysicalSender}; use crate::runtime::streaming::protocol::control::{ControlCommand, JobMasterEvent, StopMode}; use crate::runtime::streaming::protocol::event::CheckpointBarrier; -use crate::runtime::streaming::state::{IoManager, IoPool, 
MemoryController, NoopMetricsCollector}; +use crate::runtime::streaming::state::{IoManager, IoPool, NoopMetricsCollector}; use crate::storage::stream_catalog::CatalogManager; #[derive(Debug, Clone)] @@ -69,6 +69,8 @@ pub struct StateConfig { pub max_background_compactions: usize, pub soft_limit_ratio: f64, pub checkpoint_interval_ms: u64, + /// Total bytes shared by all [`crate::runtime::streaming::state::OperatorStateStore`] (global pool). + pub per_operator_memory_bytes: u64, } impl Default for StateConfig { @@ -78,6 +80,7 @@ impl Default for StateConfig { max_background_compactions: 2, soft_limit_ratio: 0.7, checkpoint_interval_ms: 10_000, + per_operator_memory_bytes: 64 * 1024 * 1024, } } } @@ -87,11 +90,6 @@ static GLOBAL_JOB_MANAGER: OnceLock> = OnceLock::new(); pub struct JobManager { active_jobs: Arc>>, operator_factory: Arc, - memory_pool: Arc, - - #[allow(dead_code)] - memory_controller: Arc, - #[allow(dead_code)] io_manager_client: IoManager, io_pool: Mutex>, state_base_dir: PathBuf, @@ -120,13 +118,9 @@ impl PipelineRunner { impl JobManager { pub fn new( operator_factory: Arc, - max_memory_bytes: usize, state_base_dir: impl AsRef, state_config: StateConfig, ) -> Result { - let soft_limit_bytes = (max_memory_bytes as f64 * state_config.soft_limit_ratio) as usize; - let memory_controller = MemoryController::new(soft_limit_bytes, max_memory_bytes); - let metrics = Arc::new(NoopMetricsCollector); let (io_pool, io_manager_client) = IoPool::try_new( state_config.max_background_spills, @@ -138,8 +132,6 @@ impl JobManager { Ok(Self { active_jobs: Arc::new(RwLock::new(HashMap::new())), operator_factory, - memory_pool: MemoryPool::new(max_memory_bytes), - memory_controller, io_manager_client, io_pool: Mutex::new(Some(io_pool)), state_base_dir: state_base_dir.as_ref().to_path_buf(), @@ -149,17 +141,11 @@ impl JobManager { pub fn init( factory: Arc, - memory_bytes: usize, state_base_dir: PathBuf, state_config: StateConfig, ) -> Result<()> { GLOBAL_JOB_MANAGER - 
.set(Arc::new(Self::new( - factory, - memory_bytes, - state_base_dir, - state_config, - )?)) + .set(Arc::new(Self::new(factory, state_base_dir, state_config)?)) .map_err(|_| anyhow!("JobManager singleton already initialized")) } @@ -485,8 +471,7 @@ impl JobManager { subtask_index, parallelism, physical_outboxes, - Arc::clone(&self.memory_pool), - Arc::clone(&self.memory_controller), + Arc::clone(&global_memory_pool()), self.io_manager_client.clone(), job_state_dir.to_path_buf(), recovery_epoch, diff --git a/src/runtime/streaming/mod.rs b/src/runtime/streaming/mod.rs index 0e4e6758..b092c85d 100644 --- a/src/runtime/streaming/mod.rs +++ b/src/runtime/streaming/mod.rs @@ -19,7 +19,6 @@ pub mod execution; pub mod factory; pub mod format; pub mod job; -pub mod memory; pub mod network; pub mod operators; pub mod protocol; diff --git a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs index 346199f6..0b8e9c79 100644 --- a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs +++ b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs @@ -725,7 +725,6 @@ impl Operator for IncrementalAggregatingFunc { let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), - ctx.memory_controller.clone(), ctx.io_manager.clone(), ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/operators/joins/join_instance.rs b/src/runtime/streaming/operators/joins/join_instance.rs index bfb6c416..e8474494 100644 --- a/src/runtime/streaming/operators/joins/join_instance.rs +++ b/src/runtime/streaming/operators/joins/join_instance.rs @@ -215,7 +215,6 @@ impl Operator for InstantJoinOperator { let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), - ctx.memory_controller.clone(), ctx.io_manager.clone(), ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git 
a/src/runtime/streaming/operators/joins/join_with_expiration.rs b/src/runtime/streaming/operators/joins/join_with_expiration.rs index 4d579715..87b838c1 100644 --- a/src/runtime/streaming/operators/joins/join_with_expiration.rs +++ b/src/runtime/streaming/operators/joins/join_with_expiration.rs @@ -242,7 +242,6 @@ impl Operator for JoinWithExpirationOperator { let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), - ctx.memory_controller.clone(), ctx.io_manager.clone(), ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/operators/windows/session_aggregating_window.rs b/src/runtime/streaming/operators/windows/session_aggregating_window.rs index 15075964..789fc4af 100644 --- a/src/runtime/streaming/operators/windows/session_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/session_aggregating_window.rs @@ -730,7 +730,6 @@ impl Operator for SessionWindowOperator { let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), - ctx.memory_controller.clone(), ctx.io_manager.clone(), ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs index 538e0dad..cf608ed4 100644 --- a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs @@ -348,7 +348,6 @@ impl Operator for SlidingWindowOperator { let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), - ctx.memory_controller.clone(), ctx.io_manager.clone(), ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs index 7bf3268d..baa21bc5 100644 --- 
a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs @@ -166,7 +166,6 @@ impl Operator for TumblingWindowOperator { let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), - ctx.memory_controller.clone(), ctx.io_manager.clone(), ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/operators/windows/window_function.rs b/src/runtime/streaming/operators/windows/window_function.rs index cf6a198d..b16c9a56 100644 --- a/src/runtime/streaming/operators/windows/window_function.rs +++ b/src/runtime/streaming/operators/windows/window_function.rs @@ -137,7 +137,6 @@ impl Operator for WindowFunctionOperator { let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), - ctx.memory_controller.clone(), ctx.io_manager.clone(), ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/protocol/event.rs b/src/runtime/streaming/protocol/event.rs index 823035f8..21be6852 100644 --- a/src/runtime/streaming/protocol/event.rs +++ b/src/runtime/streaming/protocol/event.rs @@ -17,7 +17,7 @@ use std::time::SystemTime; use arrow_array::RecordBatch; -use crate::runtime::streaming::memory::MemoryTicket; +use crate::runtime::memory::MemoryTicket; #[derive(Debug, Copy, Clone, PartialEq, Eq, Encode, Decode, Serialize, Deserialize)] pub enum Watermark { diff --git a/src/runtime/streaming/state/metrics.rs b/src/runtime/streaming/state/metrics.rs index c6d5ae4e..4a86a64f 100644 --- a/src/runtime/streaming/state/metrics.rs +++ b/src/runtime/streaming/state/metrics.rs @@ -2,7 +2,7 @@ // you may not use this file except in compliance with the License. 
pub trait StateMetricsCollector: Send + Sync + 'static { - fn record_memory_usage(&self, operator_id: u32, bytes: usize); + fn record_memory_usage(&self, operator_id: u32, bytes: u64); fn record_spill_duration(&self, operator_id: u32, duration_ms: u128); fn record_compaction_duration(&self, operator_id: u32, is_major: bool, duration_ms: u128); fn inc_io_errors(&self, operator_id: u32); @@ -11,7 +11,7 @@ pub trait StateMetricsCollector: Send + Sync + 'static { /// Default no-op implementation. pub struct NoopMetricsCollector; impl StateMetricsCollector for NoopMetricsCollector { - fn record_memory_usage(&self, _: u32, _: usize) {} + fn record_memory_usage(&self, _: u32, _: u64) {} fn record_spill_duration(&self, _: u32, _: u128) {} fn record_compaction_duration(&self, _: u32, _: bool, _: u128) {} fn inc_io_errors(&self, _: u32) {} diff --git a/src/runtime/streaming/state/mod.rs b/src/runtime/streaming/state/mod.rs index ae14ad62..7d5bb3ef 100644 --- a/src/runtime/streaming/state/mod.rs +++ b/src/runtime/streaming/state/mod.rs @@ -22,4 +22,4 @@ pub use io_manager::{CompactJob, IoManager, IoPool, SpillJob}; #[allow(unused_imports)] pub use metrics::{NoopMetricsCollector, StateMetricsCollector}; #[allow(unused_imports)] -pub use operator_state::{MemoryController, OperatorStateStore}; +pub use operator_state::OperatorStateStore; diff --git a/src/runtime/streaming/state/operator_state.rs b/src/runtime/streaming/state/operator_state.rs index 84838ca2..68e111fb 100644 --- a/src/runtime/streaming/state/operator_state.rs +++ b/src/runtime/streaming/state/operator_state.rs @@ -4,6 +4,7 @@ use super::error::{Result, StateEngineError}; use super::io_manager::{CompactJob, IoManager, SpillJob}; use super::metrics::StateMetricsCollector; +use crate::runtime::memory::{global_state_memory_pool, MemoryPool}; use arrow_array::builder::{BinaryBuilder, BooleanBuilder, UInt64Builder}; use arrow_array::{Array, BinaryArray, RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, 
Schema}; @@ -16,7 +17,7 @@ use std::collections::{HashMap, HashSet, VecDeque}; use std::fs::{self, File}; use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use tokio::sync::Notify; use uuid::Uuid; @@ -28,38 +29,6 @@ pub type TombstoneMap = HashMap; const TOMBSTONE_ENTRY_OVERHEAD: usize = 8 + 16; -#[derive(Debug)] -pub struct MemoryController { - current_usage: AtomicUsize, - hard_limit: usize, - soft_limit: usize, -} - -impl MemoryController { - pub fn new(soft_limit: usize, hard_limit: usize) -> Arc { - Arc::new(Self { - current_usage: AtomicUsize::new(0), - hard_limit, - soft_limit, - }) - } - pub fn exceeds_hard_limit(&self, incoming: usize) -> bool { - self.current_usage.load(Ordering::Relaxed) + incoming > self.hard_limit - } - pub fn should_spill(&self) -> bool { - self.current_usage.load(Ordering::Relaxed) > self.soft_limit - } - pub fn record_inc(&self, bytes: usize) { - self.current_usage.fetch_add(bytes, Ordering::Relaxed); - } - pub fn record_dec(&self, bytes: usize) { - self.current_usage.fetch_sub(bytes, Ordering::Relaxed); - } - pub fn usage_bytes(&self) -> usize { - self.current_usage.load(Ordering::Relaxed) - } -} - pub struct OperatorStateStore { pub operator_id: u32, current_epoch: AtomicU64, @@ -71,7 +40,8 @@ pub struct OperatorStateStore { tombstone_files: RwLock>, tombstones: RwLock, - mem_ctrl: Arc, + state_block: Arc, + soft_limit: u64, io_manager: IoManager, data_dir: PathBuf, @@ -82,13 +52,14 @@ pub struct OperatorStateStore { is_compacting: AtomicBool, } +const DEFAULT_SOFT_LIMIT_RATIO: f64 = 0.7; + impl OperatorStateStore { - pub fn new( - operator_id: u32, - base_dir: impl AsRef, - mem_ctrl: Arc, - io_manager: IoManager, - ) -> Result> { + pub fn new(operator_id: u32, base_dir: impl AsRef, io_manager: IoManager) -> Result> { + let state_block = global_state_memory_pool(); + let (_, quota) = 
state_block.usage_metrics(); + let soft_limit = (quota as f64 * DEFAULT_SOFT_LIMIT_RATIO) as u64; + let op_dir = base_dir.as_ref().join(format!("op_{operator_id}")); let data_dir = op_dir.join("data"); let tombstone_dir = op_dir.join("tombstones"); @@ -104,7 +75,8 @@ impl OperatorStateStore { data_files: RwLock::new(Vec::new()), tombstone_files: RwLock::new(Vec::new()), tombstones: RwLock::new(HashMap::new()), - mem_ctrl, + state_block, + soft_limit, io_manager, data_dir, tombstone_dir, @@ -114,21 +86,30 @@ impl OperatorStateStore { })) } + fn state_exceeds_hard_limit(&self, incoming: usize) -> bool { + let (used, quota) = self.state_block.usage_metrics(); + used + incoming as u64 > quota + } + + fn state_should_spill(&self) -> bool { + self.state_block.usage_metrics().0 > self.soft_limit + } + pub async fn put(self: &Arc, key: PartitionKey, batch: RecordBatch) -> Result<()> { let size = batch.get_array_memory_size(); - while self.mem_ctrl.exceeds_hard_limit(size) { + while self.state_exceeds_hard_limit(size) { self.trigger_spill(); self.spill_notify.notified().await; } - self.mem_ctrl.record_inc(size); + self.state_block.force_reserve(size as u64); self.active_table .write() .entry(key) .or_default() .push(batch); - if self.mem_ctrl.should_spill() { + if self.state_should_spill() { self.downgrade_active_table(self.current_epoch.load(Ordering::Acquire)); self.trigger_spill(); } @@ -142,20 +123,20 @@ impl OperatorStateStore { { let mut tb_guard = self.tombstones.write(); if tb_guard.insert(key.clone(), current_ep).is_none() { - self.mem_ctrl.record_inc(tombstone_mem_size); + self.state_block.force_reserve(tombstone_mem_size as u64); } } if let Some(batches) = self.active_table.write().remove(&key) { let released: usize = batches.iter().map(|b| b.get_array_memory_size()).sum(); - self.mem_ctrl.record_dec(released); + self.state_block.force_release(released as u64); } let mut imm = self.immutable_tables.lock(); for (_, table) in imm.iter_mut() { if let Some(batches) = 
table.remove(&key) { let released: usize = batches.iter().map(|b| b.get_array_memory_size()).sum(); - self.mem_ctrl.record_dec(released); + self.state_block.force_release(released as u64); } } @@ -347,8 +328,8 @@ impl OperatorStateStore { self.tombstone_files.write().push(path); } - self.mem_ctrl.record_dec(size_to_release); - metrics.record_memory_usage(self.operator_id, self.mem_ctrl.usage_bytes()); + self.state_block.force_release(size_to_release as u64); + metrics.record_memory_usage(self.operator_id, self.state_block.usage_metrics().0); self.is_spilling.store(false, Ordering::SeqCst); self.spill_notify.notify_waiters(); @@ -441,8 +422,8 @@ impl OperatorStateStore { }); if memory_freed > 0 { - self.mem_ctrl.record_dec(memory_freed); - metrics.record_memory_usage(self.operator_id, self.mem_ctrl.usage_bytes()); + self.state_block.force_release(memory_freed as u64); + metrics.record_memory_usage(self.operator_id, self.state_block.usage_metrics().0); } } @@ -521,7 +502,7 @@ impl OperatorStateStore { for key in loaded_tombstones.keys() { total_tombstone_mem += key.len() + TOMBSTONE_ENTRY_OVERHEAD; } - self.mem_ctrl.record_inc(total_tombstone_mem); + self.state_block.force_reserve(total_tombstone_mem as u64); *self.tombstones.write() = loaded_tombstones.clone(); let data_paths = self.data_files.read().clone(); @@ -741,18 +722,31 @@ mod tests { .unwrap() } - fn setup() -> (TempDir, Arc, IoManager, IoPool) { + const TEST_OPERATOR_MEMORY: u64 = 2 * 1024 * 1024; + + fn ensure_global_state_pool() { + use std::sync::Once; + use crate::runtime::memory::{init_global_state_memory_pool, try_global_state_memory_pool}; + static INIT: Once = Once::new(); + INIT.call_once(|| { + if try_global_state_memory_pool().is_err() { + init_global_state_memory_pool(TEST_OPERATOR_MEMORY).expect("state pool init"); + } + }); + } + + fn setup() -> (TempDir, IoManager, IoPool) { + ensure_global_state_pool(); let tmp = TempDir::new().unwrap(); - let mem = MemoryController::new(1024 * 1024, 2 * 
1024 * 1024); let metrics: Arc = Arc::new(NoopMetricsCollector); let (pool, mgr) = IoPool::try_new(1, 1, metrics).unwrap(); - (tmp, mem, mgr, pool) + (tmp, mgr, pool) } #[tokio::test] async fn test_put_and_get() { - let (tmp, mem, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); let key = b"key-a".to_vec(); let batch = make_batch(&[10, 20, 30]); @@ -770,8 +764,8 @@ mod tests { #[tokio::test] async fn test_multiple_puts_same_key() { - let (tmp, mem, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); let key = b"key-x".to_vec(); store.put(key.clone(), make_batch(&[1])).await.unwrap(); @@ -783,8 +777,8 @@ mod tests { #[tokio::test] async fn test_get_nonexistent_key() { - let (tmp, mem, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); let result = store.get_batches(b"no-such-key").await.unwrap(); assert!(result.is_empty()); @@ -792,8 +786,8 @@ mod tests { #[tokio::test] async fn test_remove_batches() { - let (tmp, mem, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); let key = b"key-del".to_vec(); store.put(key.clone(), make_batch(&[42])).await.unwrap(); @@ -806,8 +800,8 @@ mod tests { #[tokio::test] async fn test_remove_does_not_affect_other_keys() { - let (tmp, mem, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); let k1 = 
b"key-1".to_vec(); let k2 = b"key-2".to_vec(); @@ -822,8 +816,8 @@ mod tests { #[tokio::test] async fn test_snapshot_epoch_advances() { - let (tmp, mem, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); store.put(b"k".to_vec(), make_batch(&[1])).await.unwrap(); store.snapshot_epoch(5).unwrap(); @@ -833,8 +827,8 @@ mod tests { #[tokio::test] async fn test_data_survives_snapshot_via_spill() { - let (tmp, mem, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); let key = b"persist".to_vec(); store.put(key.clone(), make_batch(&[99])).await.unwrap(); @@ -853,8 +847,8 @@ mod tests { #[tokio::test] async fn test_tombstone_hides_immutable_data() { - let (tmp, mem, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap(); + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); let key = b"will-die".to_vec(); store.put(key.clone(), make_batch(&[7])).await.unwrap(); @@ -871,30 +865,31 @@ mod tests { } #[tokio::test] - async fn test_memory_controller_tracking() { - let mem = MemoryController::new(1024, 2048); - assert_eq!(mem.usage_bytes(), 0); + async fn test_state_block_tracking() { + let mem = MemoryPool::new(2048); + assert_eq!(mem.usage_metrics().0, 0); - mem.record_inc(100); - assert_eq!(mem.usage_bytes(), 100); + mem.force_reserve(100); + assert_eq!(mem.usage_metrics().0, 100); - mem.record_dec(40); - assert_eq!(mem.usage_bytes(), 60); + mem.force_release(40); + assert_eq!(mem.usage_metrics().0, 60); - assert!(!mem.should_spill()); - mem.record_inc(1000); - assert!(mem.should_spill()); + let soft_limit = 1000u64; + assert!(mem.usage_metrics().0 <= soft_limit); + mem.force_reserve(1000); + 
assert!(mem.usage_metrics().0 > soft_limit); } #[tokio::test] - async fn test_memory_controller_hard_limit() { - let mem = MemoryController::new(512, 1024); - assert!(!mem.exceeds_hard_limit(500)); - assert!(mem.exceeds_hard_limit(1025)); - - mem.record_inc(800); - assert!(mem.exceeds_hard_limit(300)); - assert!(!mem.exceeds_hard_limit(200)); + async fn test_state_block_hard_limit() { + let mem = MemoryPool::new(1024); + assert!(mem.usage_metrics().0 + 500 <= mem.usage_metrics().1); + assert!(mem.usage_metrics().0 + 1025 > mem.usage_metrics().1); + + mem.force_reserve(800); + assert!(mem.usage_metrics().0 + 300 > mem.usage_metrics().1); + assert!(mem.usage_metrics().0 + 200 <= mem.usage_metrics().1); } #[test] diff --git a/src/runtime/buffer_and_event/buffer_or_event.rs b/src/runtime/wasm/buffer_and_event/buffer_or_event.rs similarity index 100% rename from src/runtime/buffer_and_event/buffer_or_event.rs rename to src/runtime/wasm/buffer_and_event/buffer_or_event.rs diff --git a/src/runtime/buffer_and_event/mod.rs b/src/runtime/wasm/buffer_and_event/mod.rs similarity index 100% rename from src/runtime/buffer_and_event/mod.rs rename to src/runtime/wasm/buffer_and_event/mod.rs diff --git a/src/runtime/buffer_and_event/stream_element/mod.rs b/src/runtime/wasm/buffer_and_event/stream_element/mod.rs similarity index 100% rename from src/runtime/buffer_and_event/stream_element/mod.rs rename to src/runtime/wasm/buffer_and_event/stream_element/mod.rs diff --git a/src/runtime/buffer_and_event/stream_element/stream_element.rs b/src/runtime/wasm/buffer_and_event/stream_element/stream_element.rs similarity index 100% rename from src/runtime/buffer_and_event/stream_element/stream_element.rs rename to src/runtime/wasm/buffer_and_event/stream_element/stream_element.rs diff --git a/src/runtime/wasm/input/input_protocol.rs b/src/runtime/wasm/input/input_protocol.rs index 69fae972..50294201 100644 --- a/src/runtime/wasm/input/input_protocol.rs +++ 
b/src/runtime/wasm/input/input_protocol.rs @@ -10,7 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; use std::time::Duration; pub trait InputProtocol: Send + Sync + 'static { diff --git a/src/runtime/wasm/input/input_provider.rs b/src/runtime/wasm/input/input_provider.rs index 3f6606cd..8eee649d 100644 --- a/src/runtime/wasm/input/input_provider.rs +++ b/src/runtime/wasm/input/input_provider.rs @@ -11,7 +11,7 @@ // limitations under the License. use crate::runtime::input::Input; -use crate::runtime::task::InputConfig; +use crate::runtime::wasm::task::InputConfig; pub struct InputProvider; diff --git a/src/runtime/wasm/input/input_runner.rs b/src/runtime/wasm/input/input_runner.rs index 854e4de8..ece85e3d 100644 --- a/src/runtime/wasm/input/input_runner.rs +++ b/src/runtime/wasm/input/input_runner.rs @@ -10,13 +10,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::common::TaskCompletionFlag; use crate::runtime::input::input_protocol::InputProtocol; use crate::runtime::input::{Input, InputState}; use crate::runtime::processor::function_error::FunctionErrorReport; -use crate::runtime::task::ControlMailBox; -use crate::runtime::task::InputRuntimeConfig; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::task::ControlMailBox; +use crate::runtime::wasm::task::InputRuntimeConfig; use crossbeam_channel::{Receiver, Sender, bounded, unbounded}; use std::sync::{Arc, Mutex}; use std::thread; @@ -250,7 +250,7 @@ impl InputRunner

{ impl Input for InputRunner

{ fn init_with_context( &mut self, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, ) -> Result<(), Box> { if !matches!(*self.state.lock().unwrap(), InputState::Uninitialized) { return Ok(()); diff --git a/src/runtime/wasm/input/interface.rs b/src/runtime/wasm/input/interface.rs index dd89ba77..06da4923 100644 --- a/src/runtime/wasm/input/interface.rs +++ b/src/runtime/wasm/input/interface.rs @@ -10,8 +10,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; -use crate::runtime::taskexecutor::InitContext; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::taskexecutor::InitContext; pub use crate::runtime::common::ComponentState as InputState; diff --git a/src/runtime/wasm/input/protocol/kafka/kafka_protocol.rs b/src/runtime/wasm/input/protocol/kafka/kafka_protocol.rs index 85336c53..1fb487a6 100644 --- a/src/runtime/wasm/input/protocol/kafka/kafka_protocol.rs +++ b/src/runtime/wasm/input/protocol/kafka/kafka_protocol.rs @@ -11,8 +11,8 @@ // limitations under the License. use super::config::KafkaConfig; -use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::input::input_protocol::InputProtocol; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; use rdkafka::Message; use rdkafka::TopicPartitionList; use rdkafka::config::ClientConfig; diff --git a/src/runtime/wasm/mod.rs b/src/runtime/wasm/mod.rs index b1c82f4c..78be72e2 100644 --- a/src/runtime/wasm/mod.rs +++ b/src/runtime/wasm/mod.rs @@ -13,6 +13,9 @@ //! WebAssembly runtime integration. 
+pub mod buffer_and_event; pub mod input; pub mod output; pub mod processor; +pub mod task; +pub mod taskexecutor; diff --git a/src/runtime/wasm/output/interface.rs b/src/runtime/wasm/output/interface.rs index e7c3b903..21c3055d 100644 --- a/src/runtime/wasm/output/interface.rs +++ b/src/runtime/wasm/output/interface.rs @@ -10,8 +10,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; -use crate::runtime::taskexecutor::InitContext; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::taskexecutor::InitContext; pub trait Output: Send + Sync { fn init_with_context( diff --git a/src/runtime/wasm/output/output_protocol.rs b/src/runtime/wasm/output/output_protocol.rs index dd502ca6..6140d3eb 100644 --- a/src/runtime/wasm/output/output_protocol.rs +++ b/src/runtime/wasm/output/output_protocol.rs @@ -10,7 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; pub trait OutputProtocol: Send + Sync + 'static { fn name(&self) -> String; diff --git a/src/runtime/wasm/output/output_provider.rs b/src/runtime/wasm/output/output_provider.rs index c6d01fef..25ca8431 100644 --- a/src/runtime/wasm/output/output_provider.rs +++ b/src/runtime/wasm/output/output_provider.rs @@ -11,7 +11,7 @@ // limitations under the License. 
use crate::runtime::output::Output; -use crate::runtime::task::OutputConfig; +use crate::runtime::wasm::task::OutputConfig; pub struct OutputProvider; diff --git a/src/runtime/wasm/output/output_runner.rs b/src/runtime/wasm/output/output_runner.rs index 85ba99b4..ca6d780c 100644 --- a/src/runtime/wasm/output/output_runner.rs +++ b/src/runtime/wasm/output/output_runner.rs @@ -10,13 +10,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::common::{ComponentState, TaskCompletionFlag}; use crate::runtime::output::Output; use crate::runtime::output::output_protocol::OutputProtocol; use crate::runtime::processor::function_error::FunctionErrorReport; -use crate::runtime::task::ControlMailBox; -use crate::runtime::task::OutputRuntimeConfig; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::task::ControlMailBox; +use crate::runtime::wasm::task::OutputRuntimeConfig; use crossbeam_channel::{Receiver, Sender, bounded, unbounded}; use std::sync::{Arc, Mutex}; use std::thread; @@ -288,7 +288,7 @@ impl OutputRunner

{ impl Output for OutputRunner

{ fn init_with_context( &mut self, - ctx: &crate::runtime::taskexecutor::InitContext, + ctx: &crate::runtime::wasm::taskexecutor::InitContext, ) -> Result<(), Box> { if !matches!(*self.state.lock().unwrap(), ComponentState::Uninitialized) { return Ok(()); diff --git a/src/runtime/wasm/output/protocol/kafka/kafka_protocol.rs b/src/runtime/wasm/output/protocol/kafka/kafka_protocol.rs index 2083294d..d9e6db4d 100644 --- a/src/runtime/wasm/output/protocol/kafka/kafka_protocol.rs +++ b/src/runtime/wasm/output/protocol/kafka/kafka_protocol.rs @@ -11,8 +11,8 @@ // limitations under the License. use super::producer_config::KafkaProducerConfig; -use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::output::output_protocol::OutputProtocol; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; use rdkafka::producer::{BaseRecord, DefaultProducerContext, Producer, ThreadedProducer}; use std::sync::Mutex; use std::time::Duration; diff --git a/src/runtime/wasm/processor/wasm/wasm_host.rs b/src/runtime/wasm/processor/wasm/wasm_host.rs index 009dd6b4..2bf7d4f0 100644 --- a/src/runtime/wasm/processor/wasm/wasm_host.rs +++ b/src/runtime/wasm/processor/wasm/wasm_host.rs @@ -10,9 +10,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::output::Output; use crate::runtime::processor::wasm::wasm_cache; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; use crate::storage::state_backend::{StateStore, StateStoreFactory}; use std::sync::{Arc, OnceLock}; use wasmtime::component::{Component, HasData, Linker, Resource, bindgen}; @@ -449,7 +449,7 @@ pub fn create_wasm_host_with_component( engine: &Engine, component: &Component, outputs: Vec>, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, task_name: String, create_time: u64, ) -> anyhow::Result<(Processor, Store)> { @@ -495,7 +495,7 @@ pub fn create_wasm_host_with_component( pub fn create_wasm_host( wasm_bytes: &[u8], outputs: Vec>, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, task_name: String, create_time: u64, ) -> anyhow::Result<(Processor, Store)> { diff --git a/src/runtime/wasm/processor/wasm/wasm_processor.rs b/src/runtime/wasm/processor/wasm/wasm_processor.rs index 1afc9dcf..52234bfe 100644 --- a/src/runtime/wasm/processor/wasm/wasm_processor.rs +++ b/src/runtime/wasm/processor/wasm/wasm_processor.rs @@ -134,7 +134,7 @@ impl WasmProcessorImpl { impl WasmProcessor for WasmProcessorImpl { fn init_with_context( &mut self, - _init_context: &crate::runtime::taskexecutor::InitContext, + _init_context: &crate::runtime::wasm::taskexecutor::InitContext, ) -> Result<(), Box> { if self.initialized { log::warn!("WasmProcessor '{}' already initialized", self.name); @@ -405,7 +405,7 @@ impl WasmProcessor for WasmProcessorImpl { fn init_wasm_host( &mut self, outputs: Vec>, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, task_name: String, create_time: u64, ) -> Result<(), Box> { diff --git a/src/runtime/wasm/processor/wasm/wasm_processor_trait.rs 
b/src/runtime/wasm/processor/wasm/wasm_processor_trait.rs index 23a9f703..fb2c17fb 100644 --- a/src/runtime/wasm/processor/wasm/wasm_processor_trait.rs +++ b/src/runtime/wasm/processor/wasm/wasm_processor_trait.rs @@ -11,7 +11,7 @@ // limitations under the License. use crate::runtime::output::Output; -use crate::runtime::taskexecutor::InitContext; +use crate::runtime::wasm::taskexecutor::InitContext; pub trait WasmProcessor: Send + Sync { fn process( diff --git a/src/runtime/wasm/processor/wasm/wasm_task.rs b/src/runtime/wasm/processor/wasm/wasm_task.rs index c61f385f..4330aaaf 100644 --- a/src/runtime/wasm/processor/wasm/wasm_task.rs +++ b/src/runtime/wasm/processor/wasm/wasm_task.rs @@ -13,13 +13,13 @@ use super::input_strategy::{InputStrategy, RoundRobinStrategy, from_selector_name}; use super::thread_pool::ThreadGroup; use super::wasm_processor_trait::WasmProcessor; -use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::common::{ComponentState, TaskCompletionFlag}; use crate::runtime::input::Input; use crate::runtime::output::Output; use crate::runtime::processor::function_error::FunctionErrorReport; -use crate::runtime::task::ProcessorRuntimeConfig; -use crate::runtime::task::{ControlMailBox, TaskControlSignal, TaskLifecycle}; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::task::ProcessorRuntimeConfig; +use crate::runtime::wasm::task::{ControlMailBox, TaskControlSignal, TaskLifecycle}; use crate::storage::task::FunctionInfo; use crossbeam_channel::{Receiver, after, select, unbounded}; use std::sync::atomic::{AtomicBool, Ordering}; @@ -120,7 +120,7 @@ impl WasmTask { pub fn init_with_context( &mut self, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, ) -> Result<(), Box> { let mut inputs = self.inputs.take().ok_or_else(|| { Box::new(std::io::Error::other("inputs already moved to thread")) @@ -262,7 +262,7 @@ impl WasmTask { 
shared_state: Arc>, failure_cause: Arc>>, execution_state: Arc>, - _init_context: crate::runtime::taskexecutor::InitContext, + _init_context: crate::runtime::wasm::taskexecutor::InitContext, ) { let mut state = TaskState::Initialized; let mut last_idx: usize = 0; @@ -729,7 +729,7 @@ impl WasmTask { impl TaskLifecycle for WasmTask { fn init_with_context( &mut self, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, ) -> Result<(), Box> { ::init_with_context(self, init_context) } diff --git a/src/runtime/task/builder/mod.rs b/src/runtime/wasm/task/builder/mod.rs similarity index 100% rename from src/runtime/task/builder/mod.rs rename to src/runtime/wasm/task/builder/mod.rs diff --git a/src/runtime/task/builder/processor/mod.rs b/src/runtime/wasm/task/builder/processor/mod.rs similarity index 97% rename from src/runtime/task/builder/processor/mod.rs rename to src/runtime/wasm/task/builder/processor/mod.rs index 418271dd..c1306924 100644 --- a/src/runtime/task/builder/processor/mod.rs +++ b/src/runtime/wasm/task/builder/processor/mod.rs @@ -19,8 +19,8 @@ use crate::runtime::output::{Output, OutputProvider}; use crate::runtime::processor::wasm::wasm_processor::WasmProcessorImpl; use crate::runtime::processor::wasm::wasm_processor_trait::WasmProcessor; use crate::runtime::processor::wasm::wasm_task::WasmTask; -use crate::runtime::task::yaml_keys::{TYPE, type_values}; -use crate::runtime::task::{InputConfig, OutputConfig, ProcessorConfig, WasmTaskConfig}; +use crate::runtime::wasm::task::yaml_keys::{TYPE, type_values}; +use crate::runtime::wasm::task::{InputConfig, OutputConfig, ProcessorConfig, WasmTaskConfig}; use serde_yaml::Value; use std::sync::Arc; diff --git a/src/runtime/task/builder/python/mod.rs b/src/runtime/wasm/task/builder/python/mod.rs similarity index 95% rename from src/runtime/task/builder/python/mod.rs rename to src/runtime/wasm/task/builder/python/mod.rs index 03f6ca0f..1b31d2e5 
100644 --- a/src/runtime/task/builder/python/mod.rs +++ b/src/runtime/wasm/task/builder/python/mod.rs @@ -20,8 +20,8 @@ use crate::runtime::processor::python::get_python_engine_and_component; use crate::runtime::processor::wasm::wasm_processor::WasmProcessorImpl; use crate::runtime::processor::wasm::wasm_processor_trait::WasmProcessor; use crate::runtime::processor::wasm::wasm_task::WasmTask; -use crate::runtime::task::yaml_keys::{TYPE, type_values}; -use crate::runtime::task::{InputConfig, OutputConfig, ProcessorConfig, WasmTaskConfig}; +use crate::runtime::wasm::task::yaml_keys::{TYPE, type_values}; +use crate::runtime::wasm::task::{InputConfig, OutputConfig, ProcessorConfig, WasmTaskConfig}; use serde_yaml::Value; use std::sync::Arc; @@ -33,7 +33,7 @@ impl PythonBuilder { yaml_value: &Value, modules: &[(String, Vec)], create_time: u64, - ) -> Result, Box> + ) -> Result, Box> { let config_type = yaml_value .get(TYPE) diff --git a/src/runtime/task/builder/sink/mod.rs b/src/runtime/wasm/task/builder/sink/mod.rs similarity index 97% rename from src/runtime/task/builder/sink/mod.rs rename to src/runtime/wasm/task/builder/sink/mod.rs index f1babbd6..65e8bc95 100644 --- a/src/runtime/task/builder/sink/mod.rs +++ b/src/runtime/wasm/task/builder/sink/mod.rs @@ -15,7 +15,7 @@ // Specifically handles building logic for Sink type configuration (future support) use crate::runtime::processor::wasm::wasm_task::WasmTask; -use crate::runtime::task::yaml_keys::{TYPE, type_values}; +use crate::runtime::wasm::task::yaml_keys::{TYPE, type_values}; use serde_yaml::Value; use std::sync::Arc; diff --git a/src/runtime/task/builder/source/mod.rs b/src/runtime/wasm/task/builder/source/mod.rs similarity index 97% rename from src/runtime/task/builder/source/mod.rs rename to src/runtime/wasm/task/builder/source/mod.rs index d766ebbe..fc81bea9 100644 --- a/src/runtime/task/builder/source/mod.rs +++ b/src/runtime/wasm/task/builder/source/mod.rs @@ -15,7 +15,7 @@ // Specifically handles 
building logic for Source type configuration (future support) use crate::runtime::processor::wasm::wasm_task::WasmTask; -use crate::runtime::task::yaml_keys::{TYPE, type_values}; +use crate::runtime::wasm::task::yaml_keys::{TYPE, type_values}; use serde_yaml::Value; use std::sync::Arc; diff --git a/src/runtime/task/builder/task_builder.rs b/src/runtime/wasm/task/builder/task_builder.rs similarity index 94% rename from src/runtime/task/builder/task_builder.rs rename to src/runtime/wasm/task/builder/task_builder.rs index 9f89dbba..2246d6d8 100644 --- a/src/runtime/task/builder/task_builder.rs +++ b/src/runtime/wasm/task/builder/task_builder.rs @@ -15,13 +15,13 @@ //! Provides unified factory methods to create TaskLifecycle instances from YAML config. //! Dispatches to specific builders (Processor, Source, Sink, Python) based on task type. -use crate::runtime::task::TaskLifecycle; -use crate::runtime::task::builder::processor::ProcessorBuilder; +use crate::runtime::wasm::task::TaskLifecycle; +use crate::runtime::wasm::task::builder::processor::ProcessorBuilder; #[cfg(feature = "python")] -use crate::runtime::task::builder::python::PythonBuilder; -use crate::runtime::task::builder::sink::SinkBuilder; -use crate::runtime::task::builder::source::SourceBuilder; -use crate::runtime::task::yaml_keys::{NAME, TYPE, type_values}; +use crate::runtime::wasm::task::builder::python::PythonBuilder; +use crate::runtime::wasm::task::builder::sink::SinkBuilder; +use crate::runtime::wasm::task::builder::source::SourceBuilder; +use crate::runtime::wasm::task::yaml_keys::{NAME, TYPE, type_values}; use serde_yaml::Value; use std::sync::Arc; diff --git a/src/runtime/task/control_mailbox.rs b/src/runtime/wasm/task/control_mailbox.rs similarity index 100% rename from src/runtime/task/control_mailbox.rs rename to src/runtime/wasm/task/control_mailbox.rs diff --git a/src/runtime/task/lifecycle.rs b/src/runtime/wasm/task/lifecycle.rs similarity index 97% rename from 
src/runtime/task/lifecycle.rs rename to src/runtime/wasm/task/lifecycle.rs index 2b857f81..ea00f7c2 100644 --- a/src/runtime/task/lifecycle.rs +++ b/src/runtime/wasm/task/lifecycle.rs @@ -15,8 +15,8 @@ // Defines the complete lifecycle management interface for Task, including initialization, start, stop, checkpoint and close use crate::runtime::common::ComponentState; -use crate::runtime::task::control_mailbox::ControlMailBox; -use crate::runtime::taskexecutor::InitContext; +use crate::runtime::wasm::task::control_mailbox::ControlMailBox; +use crate::runtime::wasm::taskexecutor::InitContext; use crate::storage::task::FunctionInfo; use std::sync::Arc; diff --git a/src/runtime/task/mod.rs b/src/runtime/wasm/task/mod.rs similarity index 100% rename from src/runtime/task/mod.rs rename to src/runtime/wasm/task/mod.rs diff --git a/src/runtime/task/processor_config.rs b/src/runtime/wasm/task/processor_config.rs similarity index 99% rename from src/runtime/task/processor_config.rs rename to src/runtime/wasm/task/processor_config.rs index fe515647..a3069adc 100644 --- a/src/runtime/task/processor_config.rs +++ b/src/runtime/wasm/task/processor_config.rs @@ -608,7 +608,7 @@ impl WasmTaskConfig { task_name: String, value: &Value, ) -> Result> { - use crate::runtime::task::yaml_keys::{INPUT_GROUPS, INPUTS, NAME, OUTPUTS}; + use crate::runtime::wasm::task::yaml_keys::{INPUT_GROUPS, INPUTS, NAME, OUTPUTS}; // 1. 
Get name from config (if exists), otherwise use the passed task_name let config_name = value diff --git a/src/runtime/task/yaml_keys.rs b/src/runtime/wasm/task/yaml_keys.rs similarity index 100% rename from src/runtime/task/yaml_keys.rs rename to src/runtime/wasm/task/yaml_keys.rs diff --git a/src/runtime/taskexecutor/init_context.rs b/src/runtime/wasm/taskexecutor/init_context.rs similarity index 97% rename from src/runtime/taskexecutor/init_context.rs rename to src/runtime/wasm/taskexecutor/init_context.rs index 13ad5c81..fca44a32 100644 --- a/src/runtime/taskexecutor/init_context.rs +++ b/src/runtime/wasm/taskexecutor/init_context.rs @@ -15,7 +15,7 @@ // Provides various resources needed for task initialization, including state storage, task storage, thread pool, etc. use crate::runtime::processor::wasm::thread_pool::{TaskThreadPool, ThreadGroup}; -use crate::runtime::task::ControlMailBox; +use crate::runtime::wasm::task::ControlMailBox; use crate::storage::state_backend::StateStorageServer; use crate::storage::task::TaskStorage; use std::sync::{Arc, Mutex}; diff --git a/src/runtime/taskexecutor/mod.rs b/src/runtime/wasm/taskexecutor/mod.rs similarity index 100% rename from src/runtime/taskexecutor/mod.rs rename to src/runtime/wasm/taskexecutor/mod.rs diff --git a/src/runtime/taskexecutor/task_manager.rs b/src/runtime/wasm/taskexecutor/task_manager.rs similarity index 98% rename from src/runtime/taskexecutor/task_manager.rs rename to src/runtime/wasm/taskexecutor/task_manager.rs index f11997d5..897e0a3d 100644 --- a/src/runtime/taskexecutor/task_manager.rs +++ b/src/runtime/wasm/taskexecutor/task_manager.rs @@ -13,8 +13,8 @@ use crate::config::GlobalConfig; use crate::runtime::common::ComponentState; use crate::runtime::processor::wasm::thread_pool::{GlobalTaskThreadPool, TaskThreadPool}; -use crate::runtime::task::{TaskBuilder, TaskLifecycle}; -use crate::runtime::taskexecutor::init_context::InitContext; +use crate::runtime::wasm::task::{TaskBuilder, 
TaskLifecycle}; +use crate::runtime::wasm::taskexecutor::init_context::InitContext; use crate::storage::state_backend::StateStorageServer; use crate::storage::task::{ FunctionInfo, StoredTaskInfo, TaskModuleBytes, TaskStorage, TaskStorageFactory, diff --git a/src/server/initializer.rs b/src/server/initializer.rs index c1e11569..f8a789f8 100644 --- a/src/server/initializer.rs +++ b/src/server/initializer.rs @@ -96,6 +96,7 @@ pub fn build_core_registry() -> ComponentRegistry { let b = ComponentRegistryBuilder::new() .register("WasmCache", initialize_wasm_cache) .register("TaskManager", initialize_task_manager) + .register("GlobalMemoryPool", initialize_global_memory_pool) .register("JobManager", initialize_job_manager); #[cfg(feature = "python")] let b = b.register("PythonService", initialize_python_service); @@ -143,7 +144,7 @@ fn initialize_wasm_cache(config: &GlobalConfig) -> Result<()> { } fn initialize_task_manager(config: &GlobalConfig) -> Result<()> { - crate::runtime::taskexecutor::TaskManager::init(config) + crate::runtime::wasm::taskexecutor::TaskManager::init(config) .context("TaskManager service failed to start")?; Ok(()) } @@ -155,32 +156,94 @@ fn initialize_python_service(config: &GlobalConfig) -> Result<()> { Ok(()) } +// Streaming heap limits from config + host probe; shared by GlobalMemoryPool and JobManager. 
+fn resolve_streaming_memory_limits(config: &GlobalConfig) -> (u64, u64) { + use crate::config::system::system_memory_info; + + let mem_info = system_memory_info().ok(); + let total_physical = mem_info.as_ref().map(|m| m.total_physical).unwrap_or(0); + let auto_runtime_bytes = (total_physical as f64 * 0.8) as u64; + + let max_memory_bytes = config + .streaming + .max_memory_bytes + .unwrap_or(if auto_runtime_bytes > 0 { + auto_runtime_bytes + } else { + 256 * 1024 * 1024 + }); + + let per_operator_memory_bytes = config + .streaming + .per_operator_state_memory_bytes + .unwrap_or(64 * 1024 * 1024); + + (max_memory_bytes, per_operator_memory_bytes) +} + +// Singleton global memory pools (streaming + operator state); registered before JobManager. +fn initialize_global_memory_pool(config: &GlobalConfig) -> Result<()> { + use crate::config::system::system_memory_info; + + let mem_info = system_memory_info().ok(); + let total_physical = mem_info.as_ref().map(|m| m.total_physical).unwrap_or(0); + let avail_physical = mem_info.as_ref().map(|m| m.available_physical).unwrap_or(0); + let total_virtual = mem_info.as_ref().map(|m| m.total_virtual).unwrap_or(0); + let avail_virtual = mem_info.as_ref().map(|m| m.available_virtual).unwrap_or(0); + + let (max_memory_bytes, per_operator_memory_bytes) = resolve_streaming_memory_limits(config); + + info!( + total_physical_mb = total_physical / (1024 * 1024), + available_physical_mb = avail_physical / (1024 * 1024), + total_virtual_mb = total_virtual / (1024 * 1024), + available_virtual_mb = avail_virtual / (1024 * 1024), + runtime_memory_mb = max_memory_bytes / (1024 * 1024), + shared_state_memory_mb = per_operator_memory_bytes / (1024 * 1024), + "GlobalMemoryPool: streaming + operator state limits (singleton)" + ); + + crate::runtime::memory::init_global_memory_pool(max_memory_bytes) + .context("Global streaming memory pool initialization failed")?; + crate::runtime::memory::init_global_state_memory_pool(per_operator_memory_bytes) + 
.context("Global operator state memory pool initialization failed")?; + + info!("GlobalMemoryPool component initialized"); + Ok(()) +} + fn initialize_job_manager(config: &GlobalConfig) -> Result<()> { use crate::runtime::streaming::factory::OperatorFactory; use crate::runtime::streaming::factory::Registry; use crate::runtime::streaming::job::{JobManager, StateConfig}; use std::sync::Arc; + let (_, per_operator_memory_bytes) = resolve_streaming_memory_limits(config); + let registry = Arc::new(Registry::new()); let factory = Arc::new(OperatorFactory::new(registry)); - let max_memory_bytes = config - .streaming - .max_memory_bytes - .unwrap_or(256 * 1024 * 1024); let state_base_dir = std::env::temp_dir().join("function-stream").join("state"); - let state_config = StateConfig::default(); + let state_config = StateConfig { + per_operator_memory_bytes, + ..StateConfig::default() + }; - JobManager::init(factory, max_memory_bytes, state_base_dir, state_config) + JobManager::init(factory, state_base_dir, state_config) .context("JobManager service failed to start")?; Ok(()) } fn initialize_coordinator(_config: &GlobalConfig) -> Result<()> { - crate::runtime::taskexecutor::TaskManager::get() + crate::runtime::wasm::taskexecutor::TaskManager::get() .context("Dependency violation: Coordinator requires TaskManager")?; + crate::runtime::memory::try_global_memory_pool() + .context("Dependency violation: Coordinator requires GlobalMemoryPool")?; + crate::runtime::memory::try_global_state_memory_pool() + .context("Dependency violation: Coordinator requires GlobalMemoryPool (state sub-pool)")?; + crate::storage::stream_catalog::CatalogManager::global() .context("Dependency violation: Coordinator requires StreamCatalog")?; From 1456a4dc9734401f74d53e72c0365fd73e9e2e90 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 19 Apr 2026 17:13:28 +0800 Subject: [PATCH 12/26] update --- conf/config.yaml | 17 +- src/config/global_config.rs | 16 +- src/config/mod.rs | 4 +- 
src/runtime/memory/block.rs | 5 + src/runtime/memory/error.rs | 29 ++ src/runtime/memory/global.rs | 27 +- src/runtime/memory/mod.rs | 5 +- src/runtime/memory/pool.rs | 43 ++- src/runtime/streaming/api/context.rs | 13 +- src/runtime/streaming/job/job_manager.rs | 48 ++- .../grouping/incremental_aggregate.rs | 6 + .../operators/joins/join_instance.rs | 6 + .../operators/joins/join_with_expiration.rs | 6 + .../windows/session_aggregating_window.rs | 6 + .../windows/sliding_aggregating_window.rs | 6 + .../windows/tumbling_aggregating_window.rs | 6 + .../operators/windows/window_function.rs | 6 + src/runtime/streaming/state/error.rs | 5 + src/runtime/streaming/state/operator_state.rs | 284 +++++++++++++----- src/server/initializer.rs | 67 +---- src/server/memory_service.rs | 61 ++++ src/server/mod.rs | 1 + 22 files changed, 492 insertions(+), 175 deletions(-) create mode 100644 src/server/memory_service.rs diff --git a/conf/config.yaml b/conf/config.yaml index cfb71d02..ea0683f3 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -51,16 +51,13 @@ wasm: # Streaming Runtime Configuration streaming: - # Global memory pool size for the streaming runtime (network buffering, backpressure). - # When not set, auto-detected as 70% of physical memory. - # Fallback: 256 MiB if detection fails. - # max_memory_bytes: 268435456 - - # Memory budget per stateful operator (aggregation, join, window). - # Each operator gets its own independent memory controller with this limit. - # When exceeded, the operator spills state to disk automatically. - # Default: 67108864 (64 MiB) - per_operator_state_memory_bytes: 67108864 + # Bytes in the global memory pool for streaming execution: pipeline buffers, batch collect, + # backpressure. Omitted → 200 MiB. + streaming_runtime_memory_bytes: 209715200 + + # Per stateful operator (join / agg / window): in-memory state store cap; spill when exceeded. + # Omitted → 100 MiB. 
+ operator_state_store_memory_bytes: 104857600 # State Storage Configuration # Used to store runtime state data for tasks diff --git a/src/config/global_config.rs b/src/config/global_config.rs index 90332c25..c1960ac9 100644 --- a/src/config/global_config.rs +++ b/src/config/global_config.rs @@ -19,11 +19,21 @@ use crate::config::python_config::PythonConfig; use crate::config::service_config::ServiceConfig; use crate::config::wasm_config::WasmConfig; +/// Default for [`StreamingConfig::streaming_runtime_memory_bytes`] when unset. **200 MiB.** +pub const DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES: u64 = 200 * 1024 * 1024; + +/// Default for [`StreamingConfig::operator_state_store_memory_bytes`] when unset. **100 MiB.** +pub const DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES: u64 = 100 * 1024 * 1024; + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct StreamingConfig { - pub max_memory_bytes: Option, - /// Total bytes for the global operator-state [`MemoryPool`](crate::runtime::memory::MemoryPool) (all stores share this quota). - pub per_operator_state_memory_bytes: Option, + /// Bytes reserved in the global memory pool for streaming execution (pipeline buffers, + /// batch collect, backpressure). + #[serde(default)] + pub streaming_runtime_memory_bytes: Option, + /// Per stateful operator: in-memory state store cap before spill. 
+ #[serde(default)] + pub operator_state_store_memory_bytes: Option, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] diff --git a/src/config/mod.rs b/src/config/mod.rs index 489063e1..55490088 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -20,7 +20,9 @@ pub mod storage; pub mod system; pub mod wasm_config; -pub use global_config::GlobalConfig; +pub use global_config::{ + DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES, GlobalConfig, +}; pub use loader::load_global_config; pub use log_config::LogConfig; #[allow(unused_imports)] diff --git a/src/runtime/memory/block.rs b/src/runtime/memory/block.rs index 18f30de5..2940b3e3 100644 --- a/src/runtime/memory/block.rs +++ b/src/runtime/memory/block.rs @@ -61,6 +61,11 @@ impl MemoryBlock { self.available_bytes.load(Ordering::Relaxed) } + #[inline] + pub fn capacity(&self) -> u64 { + self.capacity + } + pub(crate) fn release_ticket(&self, bytes: u64) { if bytes > 0 { self.available_bytes.fetch_add(bytes, Ordering::Release); diff --git a/src/runtime/memory/error.rs b/src/runtime/memory/error.rs index a5d5152f..008d5c71 100644 --- a/src/runtime/memory/error.rs +++ b/src/runtime/memory/error.rs @@ -17,6 +17,7 @@ use std::fmt; pub enum MemoryError { AlreadyInitialized, Uninitialized, + OsAllocationFailed { bytes: u64 }, } impl fmt::Display for MemoryError { @@ -28,8 +29,36 @@ impl fmt::Display for MemoryError { MemoryError::Uninitialized => { write!(f, "Global memory pool is not initialized") } + MemoryError::OsAllocationFailed { bytes } => { + write!( + f, + "insufficient memory: failed to reserve {} bytes (virtual capacity for pool cap) from the OS allocator", + bytes + ) + } } } } impl std::error::Error for MemoryError {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MemoryAllocationError { + InsufficientCapacity, + RequestLargerThanPool, +} + +impl fmt::Display for MemoryAllocationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match 
self { + MemoryAllocationError::InsufficientCapacity => { + write!(f, "Insufficient capacity in memory pool") + } + MemoryAllocationError::RequestLargerThanPool => { + write!(f, "Requested block exceeds memory pool maximum") + } + } + } +} + +impl std::error::Error for MemoryAllocationError {} diff --git a/src/runtime/memory/global.rs b/src/runtime/memory/global.rs index 97a38d1f..42920147 100644 --- a/src/runtime/memory/global.rs +++ b/src/runtime/memory/global.rs @@ -17,11 +17,11 @@ use super::error::MemoryError; use super::pool::MemoryPool; static GLOBAL_POOL: OnceLock> = OnceLock::new(); -static GLOBAL_STATE_POOL: OnceLock> = OnceLock::new(); pub fn init_global_memory_pool(max_bytes: u64) -> Result<(), MemoryError> { + let pool = MemoryPool::try_new(max_bytes)?; GLOBAL_POOL - .set(MemoryPool::new(max_bytes)) + .set(pool) .map_err(|_| MemoryError::AlreadyInitialized) } @@ -31,26 +31,9 @@ pub fn try_global_memory_pool() -> Result, MemoryError> { #[inline] pub fn global_memory_pool() -> Arc { - try_global_memory_pool().expect("Global streaming pool must be initialized before use") + try_global_memory_pool().expect("Global memory pool must be initialized before use") } -pub fn init_global_state_memory_pool(max_bytes: u64) -> Result<(), MemoryError> { - GLOBAL_STATE_POOL - .set(MemoryPool::new(max_bytes)) - .map_err(|_| MemoryError::AlreadyInitialized) -} - -pub fn try_global_state_memory_pool() -> Result, MemoryError> { - GLOBAL_STATE_POOL.get().cloned().ok_or(MemoryError::Uninitialized) -} - -#[inline] -pub fn global_state_memory_pool() -> Arc { - try_global_state_memory_pool().expect("Global state pool must be initialized before use") -} - -pub fn get_memory_metrics() -> (Option<(u64, u64)>, Option<(u64, u64)>) { - let stream_metrics = GLOBAL_POOL.get().map(|p| p.usage_metrics()); - let state_metrics = GLOBAL_STATE_POOL.get().map(|p| p.usage_metrics()); - (stream_metrics, state_metrics) +pub fn get_memory_metrics() -> Option<(u64, u64)> { + 
GLOBAL_POOL.get().map(|p| p.usage_metrics()) } diff --git a/src/runtime/memory/mod.rs b/src/runtime/memory/mod.rs index 5a50028d..01a917a7 100644 --- a/src/runtime/memory/mod.rs +++ b/src/runtime/memory/mod.rs @@ -22,11 +22,10 @@ pub mod ticket; #[allow(unused_imports)] pub use block::MemoryBlock; #[allow(unused_imports)] -pub use error::MemoryError; +pub use error::{MemoryAllocationError, MemoryError}; #[allow(unused_imports)] pub use global::{ - get_memory_metrics, global_memory_pool, global_state_memory_pool, init_global_memory_pool, - init_global_state_memory_pool, try_global_memory_pool, try_global_state_memory_pool, + get_memory_metrics, global_memory_pool, init_global_memory_pool, try_global_memory_pool, }; pub use pool::MemoryPool; pub use ticket::MemoryTicket; diff --git a/src/runtime/memory/pool.rs b/src/runtime/memory/pool.rs index 3261c9b0..98592d35 100644 --- a/src/runtime/memory/pool.rs +++ b/src/runtime/memory/pool.rs @@ -18,6 +18,7 @@ use tokio::sync::Notify; use tracing::{debug, warn}; use super::block::MemoryBlock; +use super::error::{MemoryAllocationError, MemoryError}; #[derive(Debug)] pub struct MemoryPool { @@ -28,19 +29,50 @@ pub struct MemoryPool { } impl MemoryPool { - pub fn new(max_bytes: u64) -> Arc { - Arc::new(Self { + pub fn try_new(max_bytes: u64) -> Result, MemoryError> { + if max_bytes > 0 { + let n = usize::try_from(max_bytes) + .map_err(|_| MemoryError::OsAllocationFailed { bytes: max_bytes })?; + let mut v = Vec::::new(); + v.try_reserve_exact(n) + .map_err(|_| MemoryError::OsAllocationFailed { bytes: max_bytes })?; + } + Ok(Arc::new(Self { max_bytes, used_bytes: AtomicU64::new(0), available_bytes: Mutex::new(max_bytes), notify: Notify::new(), - }) + })) + } + + pub fn new(max_bytes: u64) -> Arc { + Self::try_new(max_bytes).expect("MemoryPool::try_new failed") } pub fn usage_metrics(&self) -> (u64, u64) { (self.used_bytes.load(Ordering::Relaxed), self.max_bytes) } + pub fn try_request_block( + self: &Arc, + bytes: u64, + ) -> 
Result, MemoryAllocationError> { + if bytes == 0 { + return Ok(MemoryBlock::new(0, self.clone())); + } + if bytes > self.max_bytes { + return Err(MemoryAllocationError::RequestLargerThanPool); + } + let mut available = self.available_bytes.lock(); + if *available >= bytes { + *available -= bytes; + self.used_bytes.fetch_add(bytes, Ordering::Relaxed); + Ok(MemoryBlock::new(bytes, self.clone())) + } else { + Err(MemoryAllocationError::InsufficientCapacity) + } + } + pub async fn request_block(self: &Arc, bytes: u64) -> Arc { if bytes == 0 { return MemoryBlock::new(0, self.clone()); @@ -67,7 +99,10 @@ impl MemoryPool { } } - debug!(bytes = bytes, "Global backpressure engaged: waiting for memory..."); + debug!( + bytes = bytes, + "Global backpressure engaged: waiting for memory..." + ); self.notify.notified().await; } } diff --git a/src/runtime/streaming/api/context.rs b/src/runtime/streaming/api/context.rs index 4e5af1a8..b5b723f5 100644 --- a/src/runtime/streaming/api/context.rs +++ b/src/runtime/streaming/api/context.rs @@ -14,10 +14,10 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, SystemTime}; -use anyhow::{anyhow, Context, Result}; +use anyhow::{Context, Result, anyhow}; use arrow_array::RecordBatch; -use crate::runtime::memory::{get_array_memory_size, MemoryPool}; +use crate::runtime::memory::{MemoryBlock, MemoryPool, get_array_memory_size}; use crate::runtime::streaming::network::endpoint::PhysicalSender; use crate::runtime::streaming::protocol::event::{StreamEvent, TrackedEvent}; use crate::runtime::streaming::state::IoManager; @@ -67,6 +67,11 @@ pub struct TaskContext { pub state_dir: PathBuf, pub io_manager: IoManager, + /// Pipeline-wide slab from the global pool; each stateful operator sub-allocates a ticket. + pub pipeline_state_memory_block: Option>, + /// Bytes reserved per stateful operator from [`Self::pipeline_state_memory_block`]. 
+ pub operator_state_memory_bytes: u64, + /// Last globally-committed safe epoch for crash recovery. safe_epoch: u64, } @@ -82,6 +87,8 @@ impl TaskContext { memory_pool: Arc, io_manager: IoManager, state_dir: PathBuf, + pipeline_state_memory_block: Option>, + operator_state_memory_bytes: u64, safe_epoch: u64, ) -> Self { let task_name = format!( @@ -101,6 +108,8 @@ impl TaskContext { config: TaskContextConfig::default(), state_dir, io_manager, + pipeline_state_memory_block, + operator_state_memory_bytes, safe_epoch, } } diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index 3b8aa0f8..457b2f23 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -12,6 +12,7 @@ use std::collections::{HashMap, HashSet}; use std::path::{Path, PathBuf}; +use std::str::FromStr; use std::sync::{Arc, Mutex, OnceLock, RwLock}; use std::time::Duration; @@ -23,6 +24,8 @@ use tracing::{debug, error, info, warn}; use protocol::function_stream_graph::{ChainedOperator, FsProgram}; +use crate::config::DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES; +use crate::runtime::memory::global_memory_pool; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::{ConstructedOperator, Operator}; use crate::runtime::streaming::api::source::SourceOperator; @@ -32,11 +35,11 @@ use crate::runtime::streaming::job::edge_manager::EdgeManager; use crate::runtime::streaming::job::models::{ PhysicalExecutionGraph, PhysicalPipeline, PipelineStatus, StreamingJobRollupStatus, }; -use crate::runtime::memory::global_memory_pool; use crate::runtime::streaming::network::endpoint::{BoxedEventStream, PhysicalSender}; use crate::runtime::streaming::protocol::control::{ControlCommand, JobMasterEvent, StopMode}; use crate::runtime::streaming::protocol::event::CheckpointBarrier; use crate::runtime::streaming::state::{IoManager, IoPool, NoopMetricsCollector}; +use 
crate::sql::logical_node::logical::OperatorName; use crate::storage::stream_catalog::CatalogManager; #[derive(Debug, Clone)] @@ -80,13 +83,36 @@ impl Default for StateConfig { max_background_compactions: 2, soft_limit_ratio: 0.7, checkpoint_interval_ms: 10_000, - per_operator_memory_bytes: 64 * 1024 * 1024, + per_operator_memory_bytes: DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, } } } static GLOBAL_JOB_MANAGER: OnceLock> = OnceLock::new(); +/// Operators that create an [`crate::runtime::streaming::state::OperatorStateStore`] at runtime. +fn pipeline_state_store_operator_count(operators: &[ChainedOperator]) -> usize { + operators + .iter() + .filter(|op| { + OperatorName::from_str(op.operator_name.as_str()) + .ok() + .is_some_and(|n| { + matches!( + n, + OperatorName::Join + | OperatorName::InstantJoin + | OperatorName::WindowFunction + | OperatorName::TumblingWindowAggregate + | OperatorName::SlidingWindowAggregate + | OperatorName::SessionWindowAggregate + | OperatorName::UpdatingAggregate + ) + }) + }) + .count() +} + pub struct JobManager { active_jobs: Arc>>, operator_factory: Arc, @@ -465,6 +491,22 @@ impl JobManager { let subtask_index = 0; let parallelism = 1; + + let per_op = self.state_config.per_operator_memory_bytes; + let n_state_ops = pipeline_state_store_operator_count(operators); + let pipeline_state_memory_block = if n_state_ops > 0 { + let bytes = per_op + .checked_mul(n_state_ops as u64) + .ok_or_else(|| anyhow!("pipeline state memory byte size overflow"))?; + Some( + global_memory_pool() + .try_request_block(bytes) + .map_err(|e| anyhow!("pipeline state memory reservation failed: {e}"))?, + ) + } else { + None + }; + let ctx = TaskContext::new( job_id.clone(), pipeline_id, @@ -474,6 +516,8 @@ impl JobManager { Arc::clone(&global_memory_pool()), self.io_manager_client.clone(), job_state_dir.to_path_buf(), + pipeline_state_memory_block, + per_op, recovery_epoch, ); diff --git a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs 
b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs index 0b8e9c79..ffa7e2f1 100644 --- a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs +++ b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs @@ -722,10 +722,16 @@ impl Operator for IncrementalAggregatingFunc { } async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/operators/joins/join_instance.rs b/src/runtime/streaming/operators/joins/join_instance.rs index e8474494..a6f4a53f 100644 --- a/src/runtime/streaming/operators/joins/join_instance.rs +++ b/src/runtime/streaming/operators/joins/join_instance.rs @@ -212,10 +212,16 @@ impl Operator for InstantJoinOperator { } async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/operators/joins/join_with_expiration.rs b/src/runtime/streaming/operators/joins/join_with_expiration.rs index 87b838c1..e044f242 100644 --- a/src/runtime/streaming/operators/joins/join_with_expiration.rs +++ b/src/runtime/streaming/operators/joins/join_with_expiration.rs @@ -239,10 +239,16 @@ impl Operator for JoinWithExpirationOperator { } async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + 
.pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/operators/windows/session_aggregating_window.rs b/src/runtime/streaming/operators/windows/session_aggregating_window.rs index 789fc4af..37be0b04 100644 --- a/src/runtime/streaming/operators/windows/session_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/session_aggregating_window.rs @@ -727,10 +727,16 @@ impl Operator for SessionWindowOperator { // Recovery & event sourcing: rebuild in-memory sessions from LSM-Tree async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs index cf608ed4..02666c03 100644 --- a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs @@ -345,10 +345,16 @@ impl Operator for SlidingWindowOperator { // Recovery: restore dual-layer state (partial panes + raw active bins) async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), 
ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs index baa21bc5..f4a17fd1 100644 --- a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs @@ -163,10 +163,16 @@ impl Operator for TumblingWindowOperator { // Recovery: replay raw data from LSM-Tree into DataFusion sessions async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/operators/windows/window_function.rs b/src/runtime/streaming/operators/windows/window_function.rs index b16c9a56..a379cd2d 100644 --- a/src/runtime/streaming/operators/windows/window_function.rs +++ b/src/runtime/streaming/operators/windows/window_function.rs @@ -134,10 +134,16 @@ impl Operator for WindowFunctionOperator { // Recovery: restore the lightweight timestamp index from LSM-Tree. // Data stays on disk until process_watermark triggers on-demand compute. 
async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; let store = OperatorStateStore::new( ctx.pipeline_id, ctx.state_dir.clone(), ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, ) .map_err(|e| anyhow!("Failed to init state store: {e}"))?; diff --git a/src/runtime/streaming/state/error.rs b/src/runtime/streaming/state/error.rs index 10c3c7c5..81ca7e6c 100644 --- a/src/runtime/streaming/state/error.rs +++ b/src/runtime/streaming/state/error.rs @@ -4,6 +4,8 @@ use crossbeam_channel::TrySendError; use thiserror::Error; +use crate::runtime::memory::MemoryAllocationError; + #[derive(Error, Debug)] pub enum StateEngineError { #[error("I/O error during state persistence: {0}")] @@ -23,6 +25,9 @@ pub enum StateEngineError { #[error("State metadata corrupted: {0}")] Corruption(String), + + #[error("State memory block reservation failed: {0}")] + MemoryReservation(#[from] MemoryAllocationError), } pub type Result = std::result::Result; diff --git a/src/runtime/streaming/state/operator_state.rs b/src/runtime/streaming/state/operator_state.rs index 68e111fb..35b8a3d6 100644 --- a/src/runtime/streaming/state/operator_state.rs +++ b/src/runtime/streaming/state/operator_state.rs @@ -4,7 +4,7 @@ use super::error::{Result, StateEngineError}; use super::io_manager::{CompactJob, IoManager, SpillJob}; use super::metrics::StateMetricsCollector; -use crate::runtime::memory::{global_state_memory_pool, MemoryPool}; +use crate::runtime::memory::{MemoryBlock, MemoryTicket}; use arrow_array::builder::{BinaryBuilder, BooleanBuilder, UInt64Builder}; use arrow_array::{Array, BinaryArray, RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; @@ -18,6 +18,7 @@ use std::fs::{self, File}; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU64, 
Ordering}; +use std::time::{Duration, Instant}; use tokio::sync::Notify; use uuid::Uuid; @@ -40,7 +41,9 @@ pub struct OperatorStateStore { tombstone_files: RwLock>, tombstones: RwLock, - state_block: Arc, + state_ticket: Arc, + state_used: AtomicU64, + state_quota: u64, soft_limit: u64, io_manager: IoManager, @@ -55,10 +58,24 @@ pub struct OperatorStateStore { const DEFAULT_SOFT_LIMIT_RATIO: f64 = 0.7; impl OperatorStateStore { - pub fn new(operator_id: u32, base_dir: impl AsRef, io_manager: IoManager) -> Result> { - let state_block = global_state_memory_pool(); - let (_, quota) = state_block.usage_metrics(); - let soft_limit = (quota as f64 * DEFAULT_SOFT_LIMIT_RATIO) as u64; + /// `pipeline_state_memory_block` is the pipeline-wide slab reserved at job spawn; this store + /// takes one ticket of `operator_state_memory_bytes` from it. + pub fn new( + operator_id: u32, + base_dir: impl AsRef, + io_manager: IoManager, + pipeline_state_memory_block: Arc, + operator_state_memory_bytes: u64, + ) -> Result> { + let ticket = pipeline_state_memory_block + .try_allocate(operator_state_memory_bytes) + .ok_or_else(|| { + StateEngineError::Corruption( + "pipeline state memory block exhausted (operator state ticket)".into(), + ) + })?; + let state_ticket = Arc::new(ticket); + let soft_limit = (operator_state_memory_bytes as f64 * DEFAULT_SOFT_LIMIT_RATIO) as u64; let op_dir = base_dir.as_ref().join(format!("op_{operator_id}")); let data_dir = op_dir.join("data"); @@ -75,7 +92,9 @@ impl OperatorStateStore { data_files: RwLock::new(Vec::new()), tombstone_files: RwLock::new(Vec::new()), tombstones: RwLock::new(HashMap::new()), - state_block, + state_ticket, + state_used: AtomicU64::new(0), + state_quota: operator_state_memory_bytes, soft_limit, io_manager, data_dir, @@ -86,23 +105,60 @@ impl OperatorStateStore { })) } - fn state_exceeds_hard_limit(&self, incoming: usize) -> bool { - let (used, quota) = self.state_block.usage_metrics(); - used + incoming as u64 > quota + fn 
state_bytes_used(&self) -> u64 { + self.state_used.load(Ordering::Relaxed) } fn state_should_spill(&self) -> bool { - self.state_block.usage_metrics().0 > self.soft_limit + self.state_bytes_used() > self.soft_limit } - pub async fn put(self: &Arc, key: PartitionKey, batch: RecordBatch) -> Result<()> { - let size = batch.get_array_memory_size(); - while self.state_exceeds_hard_limit(size) { + fn rebuild_state_used_from_tables(&self) { + let mut n = 0u64; + for rows in self.active_table.read().values() { + for b in rows { + n += b.get_array_memory_size() as u64; + } + } + for (_, table) in self.immutable_tables.lock().iter() { + for rows in table.values() { + for b in rows { + n += b.get_array_memory_size() as u64; + } + } + } + self.state_used.store(n, Ordering::Release); + } + + async fn wait_until_memory_available_async(self: Arc, need: u64) { + while self.state_used.load(Ordering::Relaxed).saturating_add(need) > self.state_quota { self.trigger_spill(); self.spill_notify.notified().await; } + } + + fn wait_until_memory_available_blocking(self: &Arc, need: u64) -> Result<()> { + loop { + if self.state_used.load(Ordering::Relaxed).saturating_add(need) <= self.state_quota { + return Ok(()); + } + self.trigger_spill(); + let start = Instant::now(); + while self.is_spilling.load(Ordering::SeqCst) { + if start.elapsed() > Duration::from_secs(120) { + return Err(StateEngineError::Corruption( + "state memory wait for spill timed out".into(), + )); + } + std::thread::sleep(Duration::from_millis(1)); + } + } + } - self.state_block.force_reserve(size as u64); + pub async fn put(self: &Arc, key: PartitionKey, batch: RecordBatch) -> Result<()> { + let size = batch.get_array_memory_size() as u64; + self.clone().wait_until_memory_available_async(size).await; + self.state_used.fetch_add(size, Ordering::Relaxed); self.active_table .write() .entry(key) @@ -116,28 +172,40 @@ impl OperatorStateStore { Ok(()) } - pub fn remove_batches(&self, key: PartitionKey) -> Result<()> { + pub fn 
remove_batches(self: &Arc, key: PartitionKey) -> Result<()> { let current_ep = self.current_epoch.load(Ordering::Acquire); - let tombstone_mem_size = key.len() + TOMBSTONE_ENTRY_OVERHEAD; + let tombstone_mem_size = (key.len() + TOMBSTONE_ENTRY_OVERHEAD) as u64; { let mut tb_guard = self.tombstones.write(); - if tb_guard.insert(key.clone(), current_ep).is_none() { - self.state_block.force_reserve(tombstone_mem_size as u64); + if !tb_guard.contains_key(&key) { + self.wait_until_memory_available_blocking(tombstone_mem_size)?; + self.state_used + .fetch_add(tombstone_mem_size, Ordering::Relaxed); + tb_guard.insert(key.clone(), current_ep); } } - if let Some(batches) = self.active_table.write().remove(&key) { - let released: usize = batches.iter().map(|b| b.get_array_memory_size()).sum(); - self.state_block.force_release(released as u64); + let released_active: u64 = self + .active_table + .write() + .remove(&key) + .map(|rows| rows.iter().map(|b| b.get_array_memory_size() as u64).sum()) + .unwrap_or(0); + + let mut released_imm = 0u64; + for (_, table) in self.immutable_tables.lock().iter_mut() { + if let Some(rows) = table.remove(&key) { + released_imm += rows + .iter() + .map(|b| b.get_array_memory_size() as u64) + .sum::(); + } } - let mut imm = self.immutable_tables.lock(); - for (_, table) in imm.iter_mut() { - if let Some(batches) = table.remove(&key) { - let released: usize = batches.iter().map(|b| b.get_array_memory_size()).sum(); - self.state_block.force_release(released as u64); - } + let released = released_active.saturating_add(released_imm); + if released > 0 { + self.state_used.fetch_sub(released, Ordering::Relaxed); } Ok(()) @@ -271,12 +339,12 @@ impl OperatorStateStore { metrics: &Arc, ) -> Result<()> { let mut batches_to_write = Vec::new(); - let mut size_to_release: usize = 0; + let mut spilled_bytes: u64 = 0; let distinct_keys_count = data.len() as u64; for (key, batches) in data { for batch in batches { - size_to_release += 
batch.get_array_memory_size(); + spilled_bytes += batch.get_array_memory_size() as u64; batches_to_write.push(inject_partition_key(&batch, &key)?); } } @@ -289,6 +357,7 @@ impl OperatorStateStore { metrics.inc_io_errors(self.operator_id); let restored = restore_memtable_from_injected_batches(batches_to_write)?; self.immutable_tables.lock().push_front((epoch, restored)); + self.rebuild_state_used_from_tables(); self.is_spilling.store(false, Ordering::SeqCst); self.spill_notify.notify_waiters(); return Err(e); @@ -328,8 +397,11 @@ impl OperatorStateStore { self.tombstone_files.write().push(path); } - self.state_block.force_release(size_to_release as u64); - metrics.record_memory_usage(self.operator_id, self.state_block.usage_metrics().0); + if spilled_bytes > 0 { + self.state_used.fetch_sub(spilled_bytes, Ordering::Relaxed); + } + + metrics.record_memory_usage(self.operator_id, self.state_bytes_used()); self.is_spilling.store(false, Ordering::SeqCst); self.spill_notify.notify_waiters(); @@ -407,24 +479,20 @@ impl OperatorStateStore { let _ = fs::remove_file(path); } - // Watermark GC { let mut tg = self.tombstones.write(); - let mut memory_freed = 0; - - tg.retain(|key, deleted_epoch| { - if *deleted_epoch <= compacted_watermark_epoch { - memory_freed += key.len() + TOMBSTONE_ENTRY_OVERHEAD; - false - } else { - true + let keys_before: Vec = tg.keys().cloned().collect(); + tg.retain(|_key, deleted_epoch| *deleted_epoch > compacted_watermark_epoch); + let mut tomb_freed = 0u64; + for k in keys_before { + if !tg.contains_key(&k) { + tomb_freed += (k.len() + TOMBSTONE_ENTRY_OVERHEAD) as u64; } - }); - - if memory_freed > 0 { - self.state_block.force_release(memory_freed as u64); - metrics.record_memory_usage(self.operator_id, self.state_block.usage_metrics().0); } + if tomb_freed > 0 { + self.state_used.fetch_sub(tomb_freed, Ordering::Relaxed); + } + metrics.record_memory_usage(self.operator_id, self.state_bytes_used()); } { @@ -445,7 +513,11 @@ impl OperatorStateStore 
{ result } - pub async fn restore_metadata(&self, safe_epoch: u64) -> Result> { + pub async fn restore_metadata( + self: &Arc, + safe_epoch: u64, + ) -> Result> { + self.state_used.store(0, Ordering::Release); self.active_table.write().clear(); self.immutable_tables .lock() @@ -465,8 +537,9 @@ impl OperatorStateStore { cleanup_future(&mut self.tombstone_files.write()); let tomb_paths = self.tombstone_files.read().clone(); - let loaded_tombstones = tokio::task::spawn_blocking(move || -> Result { - let mut map = HashMap::new(); + type RawTombstones = HashMap; + let raw_tombstones = tokio::task::spawn_blocking(move || -> Result { + let mut map = RawTombstones::new(); for path in tomb_paths { let file = File::open(&path).map_err(StateEngineError::IoError)?; let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; @@ -498,12 +571,19 @@ impl OperatorStateStore { .await .map_err(|_| StateEngineError::Corruption("Task Panicked".into()))??; - let mut total_tombstone_mem = 0; - for key in loaded_tombstones.keys() { - total_tombstone_mem += key.len() + TOMBSTONE_ENTRY_OVERHEAD; + let tomb_epoch_map = raw_tombstones.clone(); + + *self.tombstones.write() = raw_tombstones; + self.rebuild_state_used_from_tables(); + let tomb_overhead: u64 = self + .tombstones + .read() + .keys() + .map(|k| (k.len() + TOMBSTONE_ENTRY_OVERHEAD) as u64) + .sum(); + if tomb_overhead > 0 { + self.state_used.fetch_add(tomb_overhead, Ordering::Relaxed); } - self.state_block.force_reserve(total_tombstone_mem as u64); - *self.tombstones.write() = loaded_tombstones.clone(); let data_paths = self.data_files.read().clone(); let active_keys = tokio::task::spawn_blocking(move || -> Result> { @@ -525,7 +605,7 @@ impl OperatorStateStore { .unwrap(); for i in 0..key_col.len() { let k = key_col.value(i).to_vec(); - let is_active = match loaded_tombstones.get(&k) { + let is_active = match tomb_epoch_map.get(&k) { Some(del_ep) => *del_ep < file_epoch, None => true, }; @@ -693,7 +773,8 @@ fn 
restore_memtable_from_injected_batches(batches: Vec) -> Result = (0..batch.num_columns()).collect(); proj.retain(|&i| i != idx); - m.entry(pk).or_default().push(batch.project(&proj)?); + let projected = batch.project(&proj)?; + m.entry(pk).or_default().push(projected); } Ok(m) } @@ -703,6 +784,7 @@ mod tests { use super::super::io_manager::IoPool; use super::super::metrics::NoopMetricsCollector; use super::*; + use crate::runtime::memory::{MemoryBlock, MemoryPool, global_memory_pool}; use arrow_array::Int64Array; use tempfile::TempDir; @@ -724,19 +806,27 @@ mod tests { const TEST_OPERATOR_MEMORY: u64 = 2 * 1024 * 1024; - fn ensure_global_state_pool() { + fn ensure_global_memory_pool() { + use crate::runtime::memory::{init_global_memory_pool, try_global_memory_pool}; use std::sync::Once; - use crate::runtime::memory::{init_global_state_memory_pool, try_global_state_memory_pool}; static INIT: Once = Once::new(); INIT.call_once(|| { - if try_global_state_memory_pool().is_err() { - init_global_state_memory_pool(TEST_OPERATOR_MEMORY).expect("state pool init"); + if try_global_memory_pool().is_err() { + init_global_memory_pool(TEST_OPERATOR_MEMORY.saturating_mul(64)) + .expect("global memory pool init"); } }); } + fn state_block(bytes: u64) -> Arc { + ensure_global_memory_pool(); + global_memory_pool() + .try_request_block(bytes) + .expect("test pipeline state memory block") + } + fn setup() -> (TempDir, IoManager, IoPool) { - ensure_global_state_pool(); + ensure_global_memory_pool(); let tmp = TempDir::new().unwrap(); let metrics: Arc = Arc::new(NoopMetricsCollector); let (pool, mgr) = IoPool::try_new(1, 1, metrics).unwrap(); @@ -746,7 +836,14 @@ mod tests { #[tokio::test] async fn test_put_and_get() { let (tmp, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); let key = b"key-a".to_vec(); 
let batch = make_batch(&[10, 20, 30]); @@ -765,7 +862,14 @@ mod tests { #[tokio::test] async fn test_multiple_puts_same_key() { let (tmp, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); let key = b"key-x".to_vec(); store.put(key.clone(), make_batch(&[1])).await.unwrap(); @@ -778,7 +882,14 @@ mod tests { #[tokio::test] async fn test_get_nonexistent_key() { let (tmp, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); let result = store.get_batches(b"no-such-key").await.unwrap(); assert!(result.is_empty()); @@ -787,7 +898,14 @@ mod tests { #[tokio::test] async fn test_remove_batches() { let (tmp, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); let key = b"key-del".to_vec(); store.put(key.clone(), make_batch(&[42])).await.unwrap(); @@ -801,7 +919,14 @@ mod tests { #[tokio::test] async fn test_remove_does_not_affect_other_keys() { let (tmp, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); let k1 = b"key-1".to_vec(); let k2 = b"key-2".to_vec(); @@ -817,7 +942,14 @@ mod tests { #[tokio::test] async fn test_snapshot_epoch_advances() { let (tmp, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + 
TEST_OPERATOR_MEMORY, + ) + .unwrap(); store.put(b"k".to_vec(), make_batch(&[1])).await.unwrap(); store.snapshot_epoch(5).unwrap(); @@ -828,7 +960,14 @@ mod tests { #[tokio::test] async fn test_data_survives_snapshot_via_spill() { let (tmp, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); let key = b"persist".to_vec(); store.put(key.clone(), make_batch(&[99])).await.unwrap(); @@ -848,7 +987,14 @@ mod tests { #[tokio::test] async fn test_tombstone_hides_immutable_data() { let (tmp, mgr, _pool) = setup(); - let store = OperatorStateStore::new(1, tmp.path(), mgr).unwrap(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); let key = b"will-die".to_vec(); store.put(key.clone(), make_batch(&[7])).await.unwrap(); diff --git a/src/server/initializer.rs b/src/server/initializer.rs index f8a789f8..86990c62 100644 --- a/src/server/initializer.rs +++ b/src/server/initializer.rs @@ -96,7 +96,7 @@ pub fn build_core_registry() -> ComponentRegistry { let b = ComponentRegistryBuilder::new() .register("WasmCache", initialize_wasm_cache) .register("TaskManager", initialize_task_manager) - .register("GlobalMemoryPool", initialize_global_memory_pool) + .register("MemoryService", initialize_memory_service) .register("JobManager", initialize_job_manager); #[cfg(feature = "python")] let b = b.register("PythonService", initialize_python_service); @@ -156,60 +156,8 @@ fn initialize_python_service(config: &GlobalConfig) -> Result<()> { Ok(()) } -// Streaming heap limits from config + host probe; shared by GlobalMemoryPool and JobManager. 
-fn resolve_streaming_memory_limits(config: &GlobalConfig) -> (u64, u64) { - use crate::config::system::system_memory_info; - - let mem_info = system_memory_info().ok(); - let total_physical = mem_info.as_ref().map(|m| m.total_physical).unwrap_or(0); - let auto_runtime_bytes = (total_physical as f64 * 0.8) as u64; - - let max_memory_bytes = config - .streaming - .max_memory_bytes - .unwrap_or(if auto_runtime_bytes > 0 { - auto_runtime_bytes - } else { - 256 * 1024 * 1024 - }); - - let per_operator_memory_bytes = config - .streaming - .per_operator_state_memory_bytes - .unwrap_or(64 * 1024 * 1024); - - (max_memory_bytes, per_operator_memory_bytes) -} - -// Singleton global memory pools (streaming + operator state); registered before JobManager. -fn initialize_global_memory_pool(config: &GlobalConfig) -> Result<()> { - use crate::config::system::system_memory_info; - - let mem_info = system_memory_info().ok(); - let total_physical = mem_info.as_ref().map(|m| m.total_physical).unwrap_or(0); - let avail_physical = mem_info.as_ref().map(|m| m.available_physical).unwrap_or(0); - let total_virtual = mem_info.as_ref().map(|m| m.total_virtual).unwrap_or(0); - let avail_virtual = mem_info.as_ref().map(|m| m.available_virtual).unwrap_or(0); - - let (max_memory_bytes, per_operator_memory_bytes) = resolve_streaming_memory_limits(config); - - info!( - total_physical_mb = total_physical / (1024 * 1024), - available_physical_mb = avail_physical / (1024 * 1024), - total_virtual_mb = total_virtual / (1024 * 1024), - available_virtual_mb = avail_virtual / (1024 * 1024), - runtime_memory_mb = max_memory_bytes / (1024 * 1024), - shared_state_memory_mb = per_operator_memory_bytes / (1024 * 1024), - "GlobalMemoryPool: streaming + operator state limits (singleton)" - ); - - crate::runtime::memory::init_global_memory_pool(max_memory_bytes) - .context("Global streaming memory pool initialization failed")?; - crate::runtime::memory::init_global_state_memory_pool(per_operator_memory_bytes) - 
.context("Global operator state memory pool initialization failed")?; - - info!("GlobalMemoryPool component initialized"); - Ok(()) +fn initialize_memory_service(config: &GlobalConfig) -> Result<()> { + crate::server::memory_service::MemoryService::initialize(config) } fn initialize_job_manager(config: &GlobalConfig) -> Result<()> { @@ -218,7 +166,10 @@ fn initialize_job_manager(config: &GlobalConfig) -> Result<()> { use crate::runtime::streaming::job::{JobManager, StateConfig}; use std::sync::Arc; - let (_, per_operator_memory_bytes) = resolve_streaming_memory_limits(config); + let per_operator_memory_bytes = config + .streaming + .operator_state_store_memory_bytes + .unwrap_or(crate::config::DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES); let registry = Arc::new(Registry::new()); let factory = Arc::new(OperatorFactory::new(registry)); @@ -240,9 +191,7 @@ fn initialize_coordinator(_config: &GlobalConfig) -> Result<()> { .context("Dependency violation: Coordinator requires TaskManager")?; crate::runtime::memory::try_global_memory_pool() - .context("Dependency violation: Coordinator requires GlobalMemoryPool")?; - crate::runtime::memory::try_global_state_memory_pool() - .context("Dependency violation: Coordinator requires GlobalMemoryPool (state sub-pool)")?; + .context("Dependency violation: Coordinator requires MemoryService")?; crate::storage::stream_catalog::CatalogManager::global() .context("Dependency violation: Coordinator requires StreamCatalog")?; diff --git a/src/server/memory_service.rs b/src/server/memory_service.rs new file mode 100644 index 00000000..2ba24eee --- /dev/null +++ b/src/server/memory_service.rs @@ -0,0 +1,61 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::{Context, Result}; +use tracing::info; + +use crate::config::{ + DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES, GlobalConfig, +}; + +pub struct MemoryService; + +impl MemoryService { + pub fn initialize(config: &GlobalConfig) -> Result<()> { + use crate::config::system::system_memory_info; + + let mem_info = system_memory_info().ok(); + let total_physical = mem_info.as_ref().map(|m| m.total_physical).unwrap_or(0); + let avail_physical = mem_info.as_ref().map(|m| m.available_physical).unwrap_or(0); + let total_virtual = mem_info.as_ref().map(|m| m.total_virtual).unwrap_or(0); + let avail_virtual = mem_info.as_ref().map(|m| m.available_virtual).unwrap_or(0); + + let streaming_runtime_memory_bytes = config + .streaming + .streaming_runtime_memory_bytes + .unwrap_or(DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES); + + let operator_state_store_memory_bytes = config + .streaming + .operator_state_store_memory_bytes + .unwrap_or(DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES); + + info!( + total_physical_mb = total_physical / (1024 * 1024), + available_physical_mb = avail_physical / (1024 * 1024), + total_virtual_mb = total_virtual / (1024 * 1024), + available_virtual_mb = avail_virtual / (1024 * 1024), + streaming_runtime_memory_mb = streaming_runtime_memory_bytes / (1024 * 1024), + operator_state_store_memory_mb = operator_state_store_memory_bytes / (1024 * 1024), + "MemoryService: global streaming + operator state pools" + ); + + let total_pool_bytes = + 
streaming_runtime_memory_bytes.saturating_add(operator_state_store_memory_bytes); + crate::runtime::memory::init_global_memory_pool(total_pool_bytes) + .context("Global memory pool initialization failed")?; + + info!("MemoryService initialized"); + Ok(()) + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index cb7a4a85..def6ac9e 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -14,6 +14,7 @@ mod handler; mod initializer; +pub mod memory_service; mod service; pub use handler::FunctionStreamServiceImpl; From 9164f4689a55e5163211903c6fcfaa054c028640 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 19 Apr 2026 17:22:04 +0800 Subject: [PATCH 13/26] update --- src/config/system.rs | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/config/system.rs b/src/config/system.rs index d7a37ddf..1a6d2967 100644 --- a/src/config/system.rs +++ b/src/config/system.rs @@ -103,10 +103,7 @@ mod sys { .arg(name) .output()?; if !output.status.success() { - return Err(io::Error::new( - io::ErrorKind::Other, - format!("sysctl {name} failed"), - )); + return Err(io::Error::other(format!("sysctl {name} failed"))); } String::from_utf8_lossy(&output.stdout) .trim() @@ -166,13 +163,13 @@ mod sys { let mut total = 0u64; let mut used = 0u64; for part in text.split_whitespace() { - if let Some(mb_str) = part.strip_suffix("M") { - if let Ok(mb) = mb_str.parse::() { - if total == 0 { - total = (mb * 1024.0 * 1024.0) as u64; - } else if used == 0 { - used = (mb * 1024.0 * 1024.0) as u64; - } + if let Some(mb_str) = part.strip_suffix("M") + && let Ok(mb) = mb_str.parse::() + { + if total == 0 { + total = (mb * 1024.0 * 1024.0) as u64; + } else if used == 0 { + used = (mb * 1024.0 * 1024.0) as u64; } } } From 85f461b9ecbd87b89f3140ca9abc8ffd1f731afd Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 19 Apr 2026 17:34:28 +0800 Subject: [PATCH 14/26] update --- protocol/proto/storage.proto | 2 -- src/runtime/streaming/job/job_manager.rs | 4 
+++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/protocol/proto/storage.proto b/protocol/proto/storage.proto index 20e14862..828bbac5 100644 --- a/protocol/proto/storage.proto +++ b/protocol/proto/storage.proto @@ -53,8 +53,6 @@ message StreamingTableDefinition { bytes fs_program_bytes = 3; string comment = 4; - // User-specified checkpoint interval from WITH clause (e.g. 'checkpoint.interval' = '5000'). - // 0 or unset means use system default. uint64 checkpoint_interval_ms = 5; // Last globally-committed checkpoint epoch. diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index 457b2f23..b66a93dc 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -66,6 +66,8 @@ pub struct StreamingJobDetail { pub program: FsProgram, } +pub const DEFAULT_CHECKPOINT_INTERVAL_MS: u64 = 60 * 1000; + #[derive(Debug, Clone)] pub struct StateConfig { pub max_background_spills: usize, @@ -82,7 +84,7 @@ impl Default for StateConfig { max_background_spills: 4, max_background_compactions: 2, soft_limit_ratio: 0.7, - checkpoint_interval_ms: 10_000, + checkpoint_interval_ms: DEFAULT_CHECKPOINT_INTERVAL_MS, per_operator_memory_bytes: DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, } } From 66186bc60e35aa5cfa0f2f4258d6222aaa78f1dc Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 19 Apr 2026 18:11:39 +0800 Subject: [PATCH 15/26] update --- conf/config.yaml | 2 + src/config/global_config.rs | 10 ++++ src/config/mod.rs | 2 + src/config/streaming_job.rs | 45 ++++++++++++++++++ src/coordinator/execution/executor.rs | 17 ++++--- src/coordinator/mod.rs | 1 + src/coordinator/streaming_table_options.rs | 47 +++++++++++++++++++ src/runtime/streaming/job/job_manager.rs | 23 +++++++-- src/runtime/streaming/state/error.rs | 9 ++++ src/runtime/streaming/state/operator_state.rs | 9 ++++ src/server/initializer.rs | 3 ++ src/sql/common/constants.rs | 2 +- src/sql/common/operator_config.rs | 9 ++++ 
src/sql/logical_node/aggregate.rs | 8 ++-- src/sql/logical_node/async_udf.rs | 2 +- src/sql/logical_node/join.rs | 2 +- src/sql/logical_node/key_calculation.rs | 2 +- src/sql/logical_node/lookup.rs | 2 +- src/sql/logical_node/projection.rs | 2 +- src/sql/logical_node/remote_table.rs | 2 +- src/sql/logical_node/sink.rs | 4 +- src/sql/logical_node/table_source.rs | 4 +- src/sql/logical_node/updating_aggregate.rs | 2 +- src/sql/logical_node/watermark_node.rs | 2 +- src/sql/logical_node/windows_function.rs | 2 +- src/sql/logical_planner/streaming_planner.rs | 5 ++ src/sql/schema/schema_provider.rs | 8 +++- 27 files changed, 197 insertions(+), 29 deletions(-) create mode 100644 src/config/streaming_job.rs create mode 100644 src/coordinator/streaming_table_options.rs diff --git a/conf/config.yaml b/conf/config.yaml index ea0683f3..1bafb944 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -58,6 +58,8 @@ streaming: # Per stateful operator (join / agg / window): in-memory state store cap; spill when exceeded. # Omitted → 100 MiB. operator_state_store_memory_bytes: 104857600 + checkpoint_interval_ms: 60000 + pipeline_parallelism: 1 # State Storage Configuration # Used to store runtime state data for tasks diff --git a/src/config/global_config.rs b/src/config/global_config.rs index c1960ac9..2f831ef4 100644 --- a/src/config/global_config.rs +++ b/src/config/global_config.rs @@ -17,6 +17,7 @@ use uuid::Uuid; use crate::config::log_config::LogConfig; use crate::config::python_config::PythonConfig; use crate::config::service_config::ServiceConfig; +use crate::config::streaming_job::{ResolvedStreamingJobConfig, StreamingJobConfig}; use crate::config::wasm_config::WasmConfig; /// Default for [`StreamingConfig::streaming_runtime_memory_bytes`] when unset. 
**200 MiB.** @@ -27,6 +28,8 @@ pub const DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES: u64 = 100 * 1024 * 1024; #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct StreamingConfig { + #[serde(flatten)] + pub job: StreamingJobConfig, /// Bytes reserved in the global memory pool for streaming execution (pipeline buffers, /// batch collect, backpressure). #[serde(default)] @@ -36,6 +39,13 @@ pub struct StreamingConfig { pub operator_state_store_memory_bytes: Option, } +impl StreamingConfig { + #[inline] + pub fn resolved_job(&self) -> ResolvedStreamingJobConfig { + self.job.resolve() + } +} + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct GlobalConfig { pub service: ServiceConfig, diff --git a/src/config/mod.rs b/src/config/mod.rs index 55490088..e523e4fd 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -17,12 +17,14 @@ pub mod paths; pub mod python_config; pub mod service_config; pub mod storage; +pub mod streaming_job; pub mod system; pub mod wasm_config; pub use global_config::{ DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES, GlobalConfig, }; +pub use streaming_job::{DEFAULT_CHECKPOINT_INTERVAL_MS, DEFAULT_PIPELINE_PARALLELISM}; pub use loader::load_global_config; pub use log_config::LogConfig; #[allow(unused_imports)] diff --git a/src/config/streaming_job.rs b/src/config/streaming_job.rs new file mode 100644 index 00000000..46a3b0ef --- /dev/null +++ b/src/config/streaming_job.rs @@ -0,0 +1,45 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +use serde::{Deserialize, Serialize}; + +pub const DEFAULT_CHECKPOINT_INTERVAL_MS: u64 = 60 * 1000; +pub const DEFAULT_PIPELINE_PARALLELISM: u32 = 1; + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct StreamingJobConfig { + #[serde(default)] + pub checkpoint_interval_ms: Option, + #[serde(default)] + pub pipeline_parallelism: Option, +} + +#[derive(Debug, Clone, Copy)] +pub struct ResolvedStreamingJobConfig { + pub checkpoint_interval_ms: u64, + pub pipeline_parallelism: u32, +} + +impl StreamingJobConfig { + pub fn resolve(&self) -> ResolvedStreamingJobConfig { + ResolvedStreamingJobConfig { + checkpoint_interval_ms: self + .checkpoint_interval_ms + .filter(|&ms| ms > 0) + .unwrap_or(DEFAULT_CHECKPOINT_INTERVAL_MS), + pipeline_parallelism: self + .pipeline_parallelism + .filter(|&p| p > 0) + .unwrap_or(DEFAULT_PIPELINE_PARALLELISM), + } + } +} diff --git a/src/coordinator/execution/executor.rs b/src/coordinator/execution/executor.rs index 0e7a79b1..78c114d1 100644 --- a/src/coordinator/execution/executor.rs +++ b/src/coordinator/execution/executor.rs @@ -20,6 +20,9 @@ use crate::coordinator::dataset::{ ExecuteResult, ShowCatalogTablesResult, ShowCreateStreamingTableResult, ShowCreateTableResult, ShowFunctionsResult, ShowStreamingTablesResult, empty_record_batch, }; +use crate::coordinator::streaming_table_options::{ + parse_checkpoint_interval_ms, parse_pipeline_parallelism, +}; use crate::coordinator::plan::{ CreateFunctionPlan, CreatePythonFunctionPlan, CreateTablePlan, CreateTablePlanBody, DropFunctionPlan, DropStreamingTablePlan, DropTablePlan, LookupTablePlan, PlanNode, @@ -318,16 +321,18 @@ impl PlanVisitor for Executor { _context: &PlanVisitorContext, ) -> PlanVisitorResult { let execute = || -> Result { - let fs_program: FsProgram = plan.program.clone().into(); + let mut fs_program: FsProgram = plan.program.clone().into(); let 
job_manager: Arc = Arc::clone(&self.job_manager); + let pipeline_parallelism = parse_pipeline_parallelism(plan.with_options.as_ref()) + .unwrap_or_else(|| job_manager.default_pipeline_parallelism()) + .max(1); + for node in &mut fs_program.nodes { + node.parallelism = pipeline_parallelism; + } let job_id = plan.name.clone(); - let custom_interval: Option = plan - .with_options - .as_ref() - .and_then(|opts| opts.get("checkpoint.interval")) - .and_then(|v| v.parse().ok()); + let custom_interval = parse_checkpoint_interval_ms(plan.with_options.as_ref()); self.catalog_manager .persist_streaming_job( diff --git a/src/coordinator/mod.rs b/src/coordinator/mod.rs index 38d4637f..86598bc5 100644 --- a/src/coordinator/mod.rs +++ b/src/coordinator/mod.rs @@ -19,6 +19,7 @@ mod execution_context; mod plan; mod runtime_context; mod statement; +mod streaming_table_options; mod tool; pub use coordinator::Coordinator; diff --git a/src/coordinator/streaming_table_options.rs b/src/coordinator/streaming_table_options.rs new file mode 100644 index 00000000..51e020b0 --- /dev/null +++ b/src/coordinator/streaming_table_options.rs @@ -0,0 +1,47 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::collections::HashMap; + +fn parse_positive_u64(raw: &str) -> Option { + let t = raw.trim().trim_matches('\''); + t.parse::().ok().filter(|&v| v > 0) +} + +fn parse_positive_u32(raw: &str) -> Option { + let t = raw.trim().trim_matches('\''); + t.parse::().ok().filter(|&v| v > 0) +} + +pub fn parse_checkpoint_interval_ms(opts: Option<&HashMap>) -> Option { + opts.and_then(|m| m.get("checkpoint.interval")) + .and_then(|s| parse_positive_u64(s)) +} + +pub fn parse_pipeline_parallelism(opts: Option<&HashMap>) -> Option { + opts.and_then(|m| m.get("parallelism")) + .and_then(|s| parse_positive_u32(s)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parses_checkpoint_and_parallelism() { + let mut m = HashMap::new(); + m.insert("checkpoint.interval".to_string(), "30000".to_string()); + m.insert("parallelism".to_string(), "2".to_string()); + assert_eq!(parse_checkpoint_interval_ms(Some(&m)), Some(30_000)); + assert_eq!(parse_pipeline_parallelism(Some(&m)), Some(2)); + } +} diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index b66a93dc..549cb314 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -24,7 +24,10 @@ use tracing::{debug, error, info, warn}; use protocol::function_stream_graph::{ChainedOperator, FsProgram}; -use crate::config::DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES; +use crate::config::{ + DEFAULT_CHECKPOINT_INTERVAL_MS, DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, + DEFAULT_PIPELINE_PARALLELISM, +}; use crate::runtime::memory::global_memory_pool; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::{ConstructedOperator, Operator}; @@ -66,14 +69,13 @@ pub struct StreamingJobDetail { pub program: FsProgram, } -pub const DEFAULT_CHECKPOINT_INTERVAL_MS: u64 = 60 * 1000; - #[derive(Debug, Clone)] pub struct StateConfig { pub max_background_spills: usize, pub max_background_compactions: 
usize, pub soft_limit_ratio: f64, pub checkpoint_interval_ms: u64, + pub pipeline_parallelism: u32, /// Total bytes shared by all [`crate::runtime::streaming::state::OperatorStateStore`] (global pool). pub per_operator_memory_bytes: u64, } @@ -85,6 +87,7 @@ impl Default for StateConfig { max_background_compactions: 2, soft_limit_ratio: 0.7, checkpoint_interval_ms: DEFAULT_CHECKPOINT_INTERVAL_MS, + pipeline_parallelism: DEFAULT_PIPELINE_PARALLELISM, per_operator_memory_bytes: DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, } } @@ -190,6 +193,11 @@ impl JobManager { } } + #[inline] + pub fn default_pipeline_parallelism(&self) -> u32 { + self.state_config.pipeline_parallelism + } + pub async fn submit_job( &self, job_id: String, @@ -218,6 +226,7 @@ impl JobManager { job_id.clone(), pipeline_id, &node.operators, + node.parallelism, &mut edge_manager, &job_state_dir, job_master_tx.clone(), @@ -445,6 +454,7 @@ impl JobManager { job_id: String, pipeline_id: u32, operators: &[ChainedOperator], + declared_parallelism: u32, edge_manager: &mut EdgeManager, job_state_dir: &Path, _job_master_tx: mpsc::Sender, @@ -492,7 +502,12 @@ impl JobManager { let status = Arc::new(RwLock::new(PipelineStatus::Initializing)); let subtask_index = 0; - let parallelism = 1; + let parallelism = if declared_parallelism > 0 { + declared_parallelism + } else { + self.state_config.pipeline_parallelism + } + .max(1); let per_op = self.state_config.per_operator_memory_bytes; let n_state_ops = pipeline_state_store_operator_count(operators); diff --git a/src/runtime/streaming/state/error.rs b/src/runtime/streaming/state/error.rs index 81ca7e6c..37bc6481 100644 --- a/src/runtime/streaming/state/error.rs +++ b/src/runtime/streaming/state/error.rs @@ -1,5 +1,14 @@ // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. use crossbeam_channel::TrySendError; use thiserror::Error; diff --git a/src/runtime/streaming/state/operator_state.rs b/src/runtime/streaming/state/operator_state.rs index 35b8a3d6..39de499d 100644 --- a/src/runtime/streaming/state/operator_state.rs +++ b/src/runtime/streaming/state/operator_state.rs @@ -1,5 +1,14 @@ // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
use super::error::{Result, StateEngineError}; use super::io_manager::{CompactJob, IoManager, SpillJob}; diff --git a/src/server/initializer.rs b/src/server/initializer.rs index 86990c62..c967831a 100644 --- a/src/server/initializer.rs +++ b/src/server/initializer.rs @@ -170,12 +170,15 @@ fn initialize_job_manager(config: &GlobalConfig) -> Result<()> { .streaming .operator_state_store_memory_bytes .unwrap_or(crate::config::DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES); + let job = config.streaming.resolved_job(); let registry = Arc::new(Registry::new()); let factory = Arc::new(OperatorFactory::new(registry)); let state_base_dir = std::env::temp_dir().join("function-stream").join("state"); let state_config = StateConfig { + checkpoint_interval_ms: job.checkpoint_interval_ms, + pipeline_parallelism: job.pipeline_parallelism, per_operator_memory_bytes, ..StateConfig::default() }; diff --git a/src/sql/common/constants.rs b/src/sql/common/constants.rs index 40642cd7..062186af 100644 --- a/src/sql/common/constants.rs +++ b/src/sql/common/constants.rs @@ -107,7 +107,7 @@ pub mod sql_field { } pub mod sql_planning_default { - pub const DEFAULT_PARALLELISM: usize = 4; + pub const DEFAULT_PARALLELISM: usize = 1; pub const PLANNING_TTL_SECS: u64 = 24 * 60 * 60; } diff --git a/src/sql/common/operator_config.rs b/src/sql/common/operator_config.rs index b5360cd7..209bee48 100644 --- a/src/sql/common/operator_config.rs +++ b/src/sql/common/operator_config.rs @@ -1,5 +1,14 @@ // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. use serde::{Deserialize, Serialize}; diff --git a/src/sql/logical_node/aggregate.rs b/src/sql/logical_node/aggregate.rs index d9833c50..9d6e5aed 100644 --- a/src/sql/logical_node/aggregate.rs +++ b/src/sql/logical_node/aggregate.rs @@ -126,7 +126,7 @@ impl StreamWindowAggregateNode { OperatorName::TumblingWindowAggregate, operator_config.encode_to_vec(), format!("TumblingWindow<{}>", operator_config.name), - 1, + planner.default_parallelism(), )) } @@ -176,7 +176,7 @@ impl StreamWindowAggregateNode { OperatorName::SlidingWindowAggregate, operator_config.encode_to_vec(), proto_operator_name::SLIDING_WINDOW_LABEL.to_string(), - 1, + planner.default_parallelism(), )) } @@ -243,7 +243,7 @@ impl StreamWindowAggregateNode { OperatorName::SessionWindowAggregate, operator_config.encode_to_vec(), operator_config.name.clone(), - 1, + planner.default_parallelism(), )) } @@ -299,7 +299,7 @@ impl StreamWindowAggregateNode { OperatorName::TumblingWindowAggregate, operator_config.encode_to_vec(), proto_operator_name::INSTANT_WINDOW_LABEL.to_string(), - 1, + planner.default_parallelism(), )) } } diff --git a/src/sql/logical_node/async_udf.rs b/src/sql/logical_node/async_udf.rs index 6cd2da7b..1c35398e 100644 --- a/src/sql/logical_node/async_udf.rs +++ b/src/sql/logical_node/async_udf.rs @@ -160,7 +160,7 @@ impl StreamingOperatorBlueprint for AsyncFunctionExecutionNode { OperatorName::AsyncUdf, operator_config.encode_to_vec(), format!("AsyncUdf<{}>", self.operator_name), - 1, + planner.default_parallelism(), ); let upstream_schema = input_schemas.remove(0); diff --git a/src/sql/logical_node/join.rs b/src/sql/logical_node/join.rs index ea142d0a..15631f1f 100644 --- a/src/sql/logical_node/join.rs +++ b/src/sql/logical_node/join.rs @@ -191,7 +191,7 @@ impl StreamingOperatorBlueprint for StreamingJoinNode { self.determine_operator_type(), operator_config.encode_to_vec(), 
runtime_operator_kind::STREAMING_JOIN.to_string(), - 1, + planner.default_parallelism(), ); let left_edge = diff --git a/src/sql/logical_node/key_calculation.rs b/src/sql/logical_node/key_calculation.rs index 6bcad784..ad0e5db0 100644 --- a/src/sql/logical_node/key_calculation.rs +++ b/src/sql/logical_node/key_calculation.rs @@ -238,7 +238,7 @@ impl StreamingOperatorBlueprint for KeyExtractionNode { engine_operator_name, protobuf_payload, format!("Key<{}>", self.operator_label.as_deref().unwrap_or("_")), - 1, + planner.default_parallelism(), ); let data_edge = diff --git a/src/sql/logical_node/lookup.rs b/src/sql/logical_node/lookup.rs index 00f624a7..c060ba82 100644 --- a/src/sql/logical_node/lookup.rs +++ b/src/sql/logical_node/lookup.rs @@ -199,7 +199,7 @@ impl StreamingOperatorBlueprint for StreamReferenceJoinNode { "DictionaryJoin<{}>", self.external_dictionary.table_identifier ), - 1, + planner.default_parallelism(), ); let incoming_edge = diff --git a/src/sql/logical_node/projection.rs b/src/sql/logical_node/projection.rs index df55e575..3c5cfccb 100644 --- a/src/sql/logical_node/projection.rs +++ b/src/sql/logical_node/projection.rs @@ -170,7 +170,7 @@ impl StreamingOperatorBlueprint for StreamProjectionNode { OperatorName::Projection, operator_config.encode_to_vec(), label, - 1, + planner.default_parallelism(), ); let routing_strategy = if self.requires_shuffle { diff --git a/src/sql/logical_node/remote_table.rs b/src/sql/logical_node/remote_table.rs index d43a87e0..bde1d47f 100644 --- a/src/sql/logical_node/remote_table.rs +++ b/src/sql/logical_node/remote_table.rs @@ -119,7 +119,7 @@ impl StreamingOperatorBlueprint for RemoteTableBoundaryNode { OperatorName::Value, operator_payload, self.table_identifier.to_string(), - 1, + planner.default_parallelism(), ); let routing_edges: Vec = input_schemas diff --git a/src/sql/logical_node/sink.rs b/src/sql/logical_node/sink.rs index dbfcaa55..2edf8f27 100644 --- a/src/sql/logical_node/sink.rs +++ 
b/src/sql/logical_node/sink.rs @@ -149,7 +149,7 @@ impl StreamingOperatorBlueprint for StreamEgressNode { fn compile_to_graph_node( &self, - _planner: &Planner, + planner: &Planner, node_index: usize, input_schemas: Vec, ) -> Result { @@ -167,7 +167,7 @@ impl StreamingOperatorBlueprint for StreamEgressNode { OperatorName::ConnectorSink, operator_payload, operator_description, - 1, + planner.default_parallelism(), ); let routing_edges: Vec = input_schemas diff --git a/src/sql/logical_node/table_source.rs b/src/sql/logical_node/table_source.rs index 65f4459f..b1c6bfdd 100644 --- a/src/sql/logical_node/table_source.rs +++ b/src/sql/logical_node/table_source.rs @@ -147,7 +147,7 @@ impl StreamingOperatorBlueprint for StreamIngestionNode { fn compile_to_graph_node( &self, - _compiler_context: &Planner, + compiler_context: &Planner, node_id_sequence: usize, upstream_schemas: Vec, ) -> Result { @@ -167,7 +167,7 @@ impl StreamingOperatorBlueprint for StreamIngestionNode { OperatorName::ConnectorSource, connector_payload, operator_description, - 1, + compiler_context.default_parallelism(), ); Ok(CompiledTopologyNode::new(execution_unit, vec![])) diff --git a/src/sql/logical_node/updating_aggregate.rs b/src/sql/logical_node/updating_aggregate.rs index 598d20eb..7e940cee 100644 --- a/src/sql/logical_node/updating_aggregate.rs +++ b/src/sql/logical_node/updating_aggregate.rs @@ -224,7 +224,7 @@ impl StreamingOperatorBlueprint for ContinuousAggregateNode { OperatorName::UpdatingAggregate, operator_config.encode_to_vec(), proto_operator_name::UPDATING_AGGREGATE.to_string(), - 1, + planner.default_parallelism(), ); let shuffle_edge = diff --git a/src/sql/logical_node/watermark_node.rs b/src/sql/logical_node/watermark_node.rs index 7c83c429..9a8fc9d6 100644 --- a/src/sql/logical_node/watermark_node.rs +++ b/src/sql/logical_node/watermark_node.rs @@ -209,7 +209,7 @@ impl StreamingOperatorBlueprint for EventTimeWatermarkNode { OperatorName::ExpressionWatermark, 
operator_config.encode_to_vec(), runtime_operator_kind::WATERMARK_GENERATOR.to_string(), - 1, + planner.default_parallelism(), ); let incoming_edge = LogicalEdge::project_all( diff --git a/src/sql/logical_node/windows_function.rs b/src/sql/logical_node/windows_function.rs index a79ceff3..8effc2ff 100644 --- a/src/sql/logical_node/windows_function.rs +++ b/src/sql/logical_node/windows_function.rs @@ -169,7 +169,7 @@ impl StreamingOperatorBlueprint for StreamingWindowFunctionNode { OperatorName::WindowFunction, operator_config.encode_to_vec(), runtime_operator_kind::STREAMING_WINDOW_EVALUATOR.to_string(), - 1, + planner.default_parallelism(), ); let routing_edge = diff --git a/src/sql/logical_planner/streaming_planner.rs b/src/sql/logical_planner/streaming_planner.rs index 4619fb3f..a8545c08 100644 --- a/src/sql/logical_planner/streaming_planner.rs +++ b/src/sql/logical_planner/streaming_planner.rs @@ -96,6 +96,11 @@ pub(crate) struct Planner<'a> { } impl<'a> Planner<'a> { + #[inline] + pub(crate) fn default_parallelism(&self) -> usize { + self.schema_provider.default_parallelism() + } + pub(crate) fn new( schema_provider: &'a StreamSchemaProvider, session_state: &'a SessionState, diff --git a/src/sql/schema/schema_provider.rs b/src/sql/schema/schema_provider.rs index d5405dd2..bbd2eef8 100644 --- a/src/sql/schema/schema_provider.rs +++ b/src/sql/schema/schema_provider.rs @@ -29,7 +29,7 @@ use crate::sql::common::constants::{planning_placeholder_udf, window_fn}; use crate::sql::logical_node::logical::{DylibUdfConfig, LogicalProgram}; use crate::sql::schema::table::Table as CatalogTable; use crate::sql::schema::utils::window_arrow_struct; -use crate::sql::types::{PlanningOptions, PlanningPlaceholderUdf}; +use crate::sql::types::{PlanningOptions, PlanningPlaceholderUdf, SqlConfig}; pub type ObjectName = UniCase; @@ -129,6 +129,7 @@ pub struct StreamPlanningContext { pub config_options: datafusion::config::ConfigOptions, pub planning_options: PlanningOptions, pub 
analyzer: Analyzer, + pub sql_config: SqlConfig, } /// Back-compat name for [`StreamPlanningContext`]. @@ -139,6 +140,11 @@ impl StreamPlanningContext { StreamPlanningContextBuilder::default() } + #[inline] + pub fn default_parallelism(&self) -> usize { + self.sql_config.default_parallelism + } + /// Same registration order as the historical `StreamSchemaProvider::new` (placeholders, then DataFusion defaults). pub fn new() -> Self { Self::builder() From 041e1307443374eac1873a3398108f346ed1e371 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 19 Apr 2026 18:18:40 +0800 Subject: [PATCH 16/26] update --- src/config/mod.rs | 2 +- src/coordinator/execution/executor.rs | 6 +++--- src/sql/common/constants.rs | 2 ++ src/sql/logical_node/aggregate.rs | 16 ++++++++++++---- src/sql/logical_node/updating_aggregate.rs | 8 +++++++- src/sql/logical_node/windows_function.rs | 8 +++++++- src/sql/logical_planner/streaming_planner.rs | 7 +++++++ 7 files changed, 39 insertions(+), 10 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index e523e4fd..e60dcfde 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -24,7 +24,6 @@ pub mod wasm_config; pub use global_config::{ DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES, GlobalConfig, }; -pub use streaming_job::{DEFAULT_CHECKPOINT_INTERVAL_MS, DEFAULT_PIPELINE_PARALLELISM}; pub use loader::load_global_config; pub use log_config::LogConfig; #[allow(unused_imports)] @@ -36,3 +35,4 @@ pub use paths::{ }; #[cfg(feature = "python")] pub use python_config::PythonConfig; +pub use streaming_job::{DEFAULT_CHECKPOINT_INTERVAL_MS, DEFAULT_PIPELINE_PARALLELISM}; diff --git a/src/coordinator/execution/executor.rs b/src/coordinator/execution/executor.rs index 78c114d1..0ae5b874 100644 --- a/src/coordinator/execution/executor.rs +++ b/src/coordinator/execution/executor.rs @@ -20,9 +20,6 @@ use crate::coordinator::dataset::{ ExecuteResult, ShowCatalogTablesResult, ShowCreateStreamingTableResult, 
ShowCreateTableResult, ShowFunctionsResult, ShowStreamingTablesResult, empty_record_batch, }; -use crate::coordinator::streaming_table_options::{ - parse_checkpoint_interval_ms, parse_pipeline_parallelism, -}; use crate::coordinator::plan::{ CreateFunctionPlan, CreatePythonFunctionPlan, CreateTablePlan, CreateTablePlanBody, DropFunctionPlan, DropStreamingTablePlan, DropTablePlan, LookupTablePlan, PlanNode, @@ -31,6 +28,9 @@ use crate::coordinator::plan::{ StartFunctionPlan, StopFunctionPlan, StreamingTable, StreamingTableConnectorPlan, }; use crate::coordinator::statement::{ConfigSource, FunctionSource}; +use crate::coordinator::streaming_table_options::{ + parse_checkpoint_interval_ms, parse_pipeline_parallelism, +}; use crate::runtime::streaming::job::JobManager; use crate::runtime::streaming::protocol::control::StopMode; use crate::runtime::wasm::taskexecutor::TaskManager; diff --git a/src/sql/common/constants.rs b/src/sql/common/constants.rs index 062186af..112db040 100644 --- a/src/sql/common/constants.rs +++ b/src/sql/common/constants.rs @@ -108,6 +108,8 @@ pub mod sql_field { pub mod sql_planning_default { pub const DEFAULT_PARALLELISM: usize = 1; + /// Parallelism for aggregations that run after `KeyBy` / shuffle on non-empty routing keys. + pub const KEYED_AGGREGATE_DEFAULT_PARALLELISM: usize = 8; pub const PLANNING_TTL_SECS: u64 = 24 * 60 * 60; } diff --git a/src/sql/logical_node/aggregate.rs b/src/sql/logical_node/aggregate.rs index 9d6e5aed..3a6d3677 100644 --- a/src/sql/logical_node/aggregate.rs +++ b/src/sql/logical_node/aggregate.rs @@ -64,6 +64,14 @@ multifield_partial_ord!( ); impl StreamWindowAggregateNode { + fn parallelism_after_keyed_shuffle(&self, planner: &Planner) -> usize { + if self.partition_keys.is_empty() { + planner.default_parallelism() + } else { + planner.keyed_aggregate_parallelism() + } + } + /// Safely constructs a new node, computing the final projection without panicking. 
pub fn try_new( window_spec: WindowBehavior, @@ -126,7 +134,7 @@ impl StreamWindowAggregateNode { OperatorName::TumblingWindowAggregate, operator_config.encode_to_vec(), format!("TumblingWindow<{}>", operator_config.name), - planner.default_parallelism(), + self.parallelism_after_keyed_shuffle(planner), )) } @@ -176,7 +184,7 @@ impl StreamWindowAggregateNode { OperatorName::SlidingWindowAggregate, operator_config.encode_to_vec(), proto_operator_name::SLIDING_WINDOW_LABEL.to_string(), - planner.default_parallelism(), + self.parallelism_after_keyed_shuffle(planner), )) } @@ -243,7 +251,7 @@ impl StreamWindowAggregateNode { OperatorName::SessionWindowAggregate, operator_config.encode_to_vec(), operator_config.name.clone(), - planner.default_parallelism(), + self.parallelism_after_keyed_shuffle(planner), )) } @@ -299,7 +307,7 @@ impl StreamWindowAggregateNode { OperatorName::TumblingWindowAggregate, operator_config.encode_to_vec(), proto_operator_name::INSTANT_WINDOW_LABEL.to_string(), - planner.default_parallelism(), + self.parallelism_after_keyed_shuffle(planner), )) } } diff --git a/src/sql/logical_node/updating_aggregate.rs b/src/sql/logical_node/updating_aggregate.rs index 7e940cee..8f5bcd31 100644 --- a/src/sql/logical_node/updating_aggregate.rs +++ b/src/sql/logical_node/updating_aggregate.rs @@ -218,13 +218,19 @@ impl StreamingOperatorBlueprint for ContinuousAggregateNode { let operator_config = self.compile_operator_config(planner, &upstream_schema)?; + let parallelism = if self.partition_key_indices.is_empty() { + planner.default_parallelism() + } else { + planner.keyed_aggregate_parallelism() + }; + let logical_node = LogicalNode::single( node_index as u32, format!("updating_aggregate_{node_index}"), OperatorName::UpdatingAggregate, operator_config.encode_to_vec(), proto_operator_name::UPDATING_AGGREGATE.to_string(), - planner.default_parallelism(), + parallelism, ); let shuffle_edge = diff --git a/src/sql/logical_node/windows_function.rs 
b/src/sql/logical_node/windows_function.rs index 8effc2ff..198e0d34 100644 --- a/src/sql/logical_node/windows_function.rs +++ b/src/sql/logical_node/windows_function.rs @@ -163,13 +163,19 @@ impl StreamingOperatorBlueprint for StreamingWindowFunctionNode { window_function_plan: evaluation_plan_payload, }; + let parallelism = if self.partition_key_indices.is_empty() { + planner.default_parallelism() + } else { + planner.keyed_aggregate_parallelism() + }; + let logical_node = LogicalNode::single( node_index as u32, format!("window_function_{node_index}"), OperatorName::WindowFunction, operator_config.encode_to_vec(), runtime_operator_kind::STREAMING_WINDOW_EVALUATOR.to_string(), - planner.default_parallelism(), + parallelism, ); let routing_edge = diff --git a/src/sql/logical_planner/streaming_planner.rs b/src/sql/logical_planner/streaming_planner.rs index a8545c08..4657c70b 100644 --- a/src/sql/logical_planner/streaming_planner.rs +++ b/src/sql/logical_planner/streaming_planner.rs @@ -42,6 +42,7 @@ use datafusion_common::TableReference; use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; +use crate::sql::common::constants::sql_planning_default; use crate::sql::common::{FsSchema, FsSchemaRef}; use crate::sql::logical_node::debezium::{ PACK_NODE_NAME, UNROLL_NODE_NAME, UnrollDebeziumPayloadNode, @@ -101,6 +102,12 @@ impl<'a> Planner<'a> { self.schema_provider.default_parallelism() } + /// Parallelism for operators that consume a keyed shuffle (non-empty partition keys). 
+ #[inline] + pub(crate) fn keyed_aggregate_parallelism(&self) -> usize { + sql_planning_default::KEYED_AGGREGATE_DEFAULT_PARALLELISM + } + pub(crate) fn new( schema_provider: &'a StreamSchemaProvider, session_state: &'a SessionState, From 4eb4d533a96c9bffb5c005ea14bb3f5eac44c510 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 19 Apr 2026 18:53:08 +0800 Subject: [PATCH 17/26] update --- conf/config.yaml | 14 ++++---- src/config/global_config.rs | 14 ++++---- src/config/streaming_job.rs | 9 +++++ src/coordinator/execution/executor.rs | 15 ++++++--- src/server/initializer.rs | 7 ++++ src/sql/common/constants.rs | 2 ++ src/sql/logical_node/aggregate.rs | 9 +++-- src/sql/logical_node/key_calculation.rs | 2 +- .../logical_node/logical/operator_chain.rs | 13 ++++++++ src/sql/logical_node/updating_aggregate.rs | 6 +--- src/sql/logical_node/windows_function.rs | 6 +--- .../logical_planner/optimizers/chaining.rs | 33 ++++++++++++++++++- src/sql/logical_planner/streaming_planner.rs | 5 +++ src/sql/mod.rs | 1 + src/sql/schema/schema_provider.rs | 11 +++++-- src/sql/types/mod.rs | 3 ++ 16 files changed, 113 insertions(+), 37 deletions(-) diff --git a/conf/config.yaml b/conf/config.yaml index 1bafb944..c83809c7 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -51,15 +51,17 @@ wasm: # Streaming Runtime Configuration streaming: - # Bytes in the global memory pool for streaming execution: pipeline buffers, batch collect, - # backpressure. Omitted → 200 MiB. - streaming_runtime_memory_bytes: 209715200 + # Global memory pool for streaming pipeline execution (buffers, batch collect, backpressure). + # Default / example: 10 MiB (10485760 bytes). + streaming_runtime_memory_bytes: 10485760 - # Per stateful operator (join / agg / window): in-memory state store cap; spill when exceeded. - # Omitted → 100 MiB. - operator_state_store_memory_bytes: 104857600 + # Per stateful operator (join / agg / window): in-memory state store cap before spill. 
+ # Default / example: 5 MiB (5242880 bytes). + operator_state_store_memory_bytes: 5242880 checkpoint_interval_ms: 60000 pipeline_parallelism: 1 + # KeyBy (key extraction) operator pipeline parallelism in planned streaming jobs. + key_by_parallelism: 1 # State Storage Configuration # Used to store runtime state data for tasks diff --git a/src/config/global_config.rs b/src/config/global_config.rs index 2f831ef4..dcfbcf5c 100644 --- a/src/config/global_config.rs +++ b/src/config/global_config.rs @@ -20,21 +20,21 @@ use crate::config::service_config::ServiceConfig; use crate::config::streaming_job::{ResolvedStreamingJobConfig, StreamingJobConfig}; use crate::config::wasm_config::WasmConfig; -/// Default for [`StreamingConfig::streaming_runtime_memory_bytes`] when unset. **200 MiB.** -pub const DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES: u64 = 200 * 1024 * 1024; +/// Default for [`StreamingConfig::streaming_runtime_memory_bytes`] when unset. **10 MiB** (pipeline buffers, backpressure). +pub const DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES: u64 = 10 * 1024 * 1024; -/// Default for [`StreamingConfig::operator_state_store_memory_bytes`] when unset. **100 MiB.** -pub const DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES: u64 = 100 * 1024 * 1024; +/// Default for [`StreamingConfig::operator_state_store_memory_bytes`] when unset. **5 MiB** per stateful operator cap. +pub const DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES: u64 = 5 * 1024 * 1024; #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct StreamingConfig { #[serde(flatten)] pub job: StreamingJobConfig, - /// Bytes reserved in the global memory pool for streaming execution (pipeline buffers, - /// batch collect, backpressure). + /// Bytes reserved in the global memory pool for streaming pipeline execution (buffers, + /// batch collect, backpressure). Default 10 MiB. #[serde(default)] pub streaming_runtime_memory_bytes: Option, - /// Per stateful operator: in-memory state store cap before spill. 
+ /// Per stateful operator: in-memory state store cap before spill. Default 5 MiB. #[serde(default)] pub operator_state_store_memory_bytes: Option, } diff --git a/src/config/streaming_job.rs b/src/config/streaming_job.rs index 46a3b0ef..6ea45609 100644 --- a/src/config/streaming_job.rs +++ b/src/config/streaming_job.rs @@ -14,6 +14,7 @@ use serde::{Deserialize, Serialize}; pub const DEFAULT_CHECKPOINT_INTERVAL_MS: u64 = 60 * 1000; pub const DEFAULT_PIPELINE_PARALLELISM: u32 = 1; +pub const DEFAULT_KEY_BY_PARALLELISM: u32 = 1; #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct StreamingJobConfig { @@ -21,12 +22,16 @@ pub struct StreamingJobConfig { pub checkpoint_interval_ms: Option, #[serde(default)] pub pipeline_parallelism: Option, + /// Physical parallelism for KeyBy / key-extraction operators in planned streaming graphs. + #[serde(default)] + pub key_by_parallelism: Option, } #[derive(Debug, Clone, Copy)] pub struct ResolvedStreamingJobConfig { pub checkpoint_interval_ms: u64, pub pipeline_parallelism: u32, + pub key_by_parallelism: u32, } impl StreamingJobConfig { @@ -40,6 +45,10 @@ impl StreamingJobConfig { .pipeline_parallelism .filter(|&p| p > 0) .unwrap_or(DEFAULT_PIPELINE_PARALLELISM), + key_by_parallelism: self + .key_by_parallelism + .filter(|&p| p > 0) + .unwrap_or(DEFAULT_KEY_BY_PARALLELISM), } } } diff --git a/src/coordinator/execution/executor.rs b/src/coordinator/execution/executor.rs index 0ae5b874..6fb03134 100644 --- a/src/coordinator/execution/executor.rs +++ b/src/coordinator/execution/executor.rs @@ -323,11 +323,16 @@ impl PlanVisitor for Executor { let execute = || -> Result { let mut fs_program: FsProgram = plan.program.clone().into(); let job_manager: Arc = Arc::clone(&self.job_manager); - let pipeline_parallelism = parse_pipeline_parallelism(plan.with_options.as_ref()) - .unwrap_or_else(|| job_manager.default_pipeline_parallelism()) - .max(1); - for node in &mut fs_program.nodes { - node.parallelism = 
pipeline_parallelism; + // Only override per-node parallelism when CREATE STREAMING TABLE specifies + // `WITH (parallelism = N)`. Otherwise keep planner-assigned values (e.g. keyed + // aggregates defaulting to a higher parallelism than the job-wide default). + if let Some(pipeline_parallelism) = + parse_pipeline_parallelism(plan.with_options.as_ref()) + { + let p = pipeline_parallelism.max(1); + for node in &mut fs_program.nodes { + node.parallelism = p; + } } let job_id = plan.name.clone(); diff --git a/src/server/initializer.rs b/src/server/initializer.rs index c967831a..a20f2ffb 100644 --- a/src/server/initializer.rs +++ b/src/server/initializer.rs @@ -19,6 +19,12 @@ use crate::config::GlobalConfig; pub type InitializerFn = fn(&GlobalConfig) -> Result<()>; +fn initialize_streaming_sql_planning(config: &GlobalConfig) -> Result<()> { + let job = config.streaming.resolved_job(); + crate::sql::planning_runtime::install_sql_planning_from_streaming_job(&job); + Ok(()) +} + #[derive(Clone)] pub struct Component { pub name: &'static str, @@ -94,6 +100,7 @@ impl ComponentRegistry { pub fn build_core_registry() -> ComponentRegistry { let builder = { let b = ComponentRegistryBuilder::new() + .register("StreamingSqlPlanning", initialize_streaming_sql_planning) .register("WasmCache", initialize_wasm_cache) .register("TaskManager", initialize_task_manager) .register("MemoryService", initialize_memory_service) diff --git a/src/sql/common/constants.rs b/src/sql/common/constants.rs index 112db040..19fdbcb3 100644 --- a/src/sql/common/constants.rs +++ b/src/sql/common/constants.rs @@ -108,6 +108,8 @@ pub mod sql_field { pub mod sql_planning_default { pub const DEFAULT_PARALLELISM: usize = 1; + /// Default physical parallelism for `KeyBy` / key-extraction pipelines (configurable via YAML). + pub const DEFAULT_KEY_BY_PARALLELISM: usize = 1; /// Parallelism for aggregations that run after `KeyBy` / shuffle on non-empty routing keys. 
pub const KEYED_AGGREGATE_DEFAULT_PARALLELISM: usize = 8; pub const PLANNING_TTL_SECS: u64 = 24 * 60 * 60; diff --git a/src/sql/logical_node/aggregate.rs b/src/sql/logical_node/aggregate.rs index 3a6d3677..1e288ab5 100644 --- a/src/sql/logical_node/aggregate.rs +++ b/src/sql/logical_node/aggregate.rs @@ -64,12 +64,11 @@ multifield_partial_ord!( ); impl StreamWindowAggregateNode { + /// This node is only emitted after `KeyExtractionNode` in streaming rewrites; `partition_keys` + /// may be empty when GROUP BY is only a window call (window column stripped from key list), + /// but the pipeline still consumes a shuffle — use keyed aggregate parallelism. fn parallelism_after_keyed_shuffle(&self, planner: &Planner) -> usize { - if self.partition_keys.is_empty() { - planner.default_parallelism() - } else { - planner.keyed_aggregate_parallelism() - } + planner.keyed_aggregate_parallelism() } /// Safely constructs a new node, computing the final projection without panicking. diff --git a/src/sql/logical_node/key_calculation.rs b/src/sql/logical_node/key_calculation.rs index ad0e5db0..ec83e108 100644 --- a/src/sql/logical_node/key_calculation.rs +++ b/src/sql/logical_node/key_calculation.rs @@ -238,7 +238,7 @@ impl StreamingOperatorBlueprint for KeyExtractionNode { engine_operator_name, protobuf_payload, format!("Key<{}>", self.operator_label.as_deref().unwrap_or("_")), - planner.default_parallelism(), + planner.key_by_parallelism(), ); let data_edge = diff --git a/src/sql/logical_node/logical/operator_chain.rs b/src/sql/logical_node/logical/operator_chain.rs index 34a01a5c..be8f3b53 100644 --- a/src/sql/logical_node/logical/operator_chain.rs +++ b/src/sql/logical_node/logical/operator_chain.rs @@ -128,4 +128,17 @@ impl OperatorChain { pub fn is_sink(&self) -> bool { self.operators[0].operator_name == OperatorName::ConnectorSink } + + /// Operators safe to run at a higher upstream `TaskContext::parallelism` when fused after a + /// stateful node (e.g. 
window aggregate @ 8 → projection @ 1). + pub fn is_parallelism_upstream_expandable(&self) -> bool { + self.operators.iter().all(|op| { + matches!( + op.operator_name, + OperatorName::Projection + | OperatorName::Value + | OperatorName::ExpressionWatermark + ) + }) + } } diff --git a/src/sql/logical_node/updating_aggregate.rs b/src/sql/logical_node/updating_aggregate.rs index 8f5bcd31..0ddb2b28 100644 --- a/src/sql/logical_node/updating_aggregate.rs +++ b/src/sql/logical_node/updating_aggregate.rs @@ -218,11 +218,7 @@ impl StreamingOperatorBlueprint for ContinuousAggregateNode { let operator_config = self.compile_operator_config(planner, &upstream_schema)?; - let parallelism = if self.partition_key_indices.is_empty() { - planner.default_parallelism() - } else { - planner.keyed_aggregate_parallelism() - }; + let parallelism = planner.keyed_aggregate_parallelism(); let logical_node = LogicalNode::single( node_index as u32, diff --git a/src/sql/logical_node/windows_function.rs b/src/sql/logical_node/windows_function.rs index 198e0d34..9be37382 100644 --- a/src/sql/logical_node/windows_function.rs +++ b/src/sql/logical_node/windows_function.rs @@ -163,11 +163,7 @@ impl StreamingOperatorBlueprint for StreamingWindowFunctionNode { window_function_plan: evaluation_plan_payload, }; - let parallelism = if self.partition_key_indices.is_empty() { - planner.default_parallelism() - } else { - planner.keyed_aggregate_parallelism() - }; + let parallelism = planner.keyed_aggregate_parallelism(); let logical_node = LogicalNode::single( node_index as u32, diff --git a/src/sql/logical_planner/optimizers/chaining.rs b/src/sql/logical_planner/optimizers/chaining.rs index 8260df19..59287d88 100644 --- a/src/sql/logical_planner/optimizers/chaining.rs +++ b/src/sql/logical_planner/optimizers/chaining.rs @@ -45,9 +45,14 @@ impl Optimizer for ChainingOptimizer { let source_node = plan.node_weight(node_idx).expect("Source node missing"); let target_node = 
plan.node_weight(target_idx).expect("Target node missing"); + let parallelism_ok = source_node.parallelism == target_node.parallelism + || target_node + .operator_chain + .is_parallelism_upstream_expandable(); + if source_node.operator_chain.is_source() || target_node.operator_chain.is_sink() - || source_node.parallelism != target_node.parallelism + || !parallelism_ok { continue; } @@ -93,6 +98,10 @@ impl Optimizer for ChainingOptimizer { source_node.description = format!("{} -> {}", source_node.description, target_node.description); + source_node.parallelism = source_node + .parallelism + .max(target_node.parallelism); + source_node .operator_chain .operators @@ -150,6 +159,28 @@ mod tests { ) } + /// Window aggregate at higher default parallelism may forward into projection @ 1: still fuse + /// so each branch does not reserve a separate global state-memory block for the same sub-chain. + #[test] + fn fusion_stateful_high_parallelism_into_expandable_low() { + let mut g = LogicalGraph::new(); + let n0 = g.add_node(source_node()); + let n1 = g.add_node(proj_node(1, "tumble")); + let n2 = g.add_node(proj_node(2, "proj")); + let n1w = g.node_weight_mut(n1).unwrap(); + n1w.parallelism = 8; + let e = forward_edge(); + g.add_edge(n0, n1, e.clone()); + g.add_edge(n1, n2, e); + + let changed = ChainingOptimizer {}.optimize_once(&mut g); + assert!(changed); + assert_eq!(g.node_count(), 2); + let fused = g.node_weights().find(|n| n.description.contains("->")).unwrap(); + assert_eq!(fused.parallelism, 8); + assert_eq!(fused.operator_chain.len(), 2); + } + /// Regression: upstream at last `NodeIndex` + remove non-last downstream swaps indices. 
#[test] fn fusion_remaps_when_upstream_was_last_node_index() { diff --git a/src/sql/logical_planner/streaming_planner.rs b/src/sql/logical_planner/streaming_planner.rs index 4657c70b..1e999c2a 100644 --- a/src/sql/logical_planner/streaming_planner.rs +++ b/src/sql/logical_planner/streaming_planner.rs @@ -102,6 +102,11 @@ impl<'a> Planner<'a> { self.schema_provider.default_parallelism() } + #[inline] + pub(crate) fn key_by_parallelism(&self) -> usize { + self.schema_provider.key_by_parallelism() + } + /// Parallelism for operators that consume a keyed shuffle (non-empty partition keys). #[inline] pub(crate) fn keyed_aggregate_parallelism(&self) -> usize { diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 71dd4dd1..529c7a2d 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -19,6 +19,7 @@ pub mod logical_node; pub mod logical_planner; pub mod parse; pub mod physical; +pub(crate) mod planning_runtime; pub mod schema; pub mod types; diff --git a/src/sql/schema/schema_provider.rs b/src/sql/schema/schema_provider.rs index bbd2eef8..26fd43e8 100644 --- a/src/sql/schema/schema_provider.rs +++ b/src/sql/schema/schema_provider.rs @@ -145,14 +145,21 @@ impl StreamPlanningContext { self.sql_config.default_parallelism } + #[inline] + pub fn key_by_parallelism(&self) -> usize { + self.sql_config.key_by_parallelism + } + /// Same registration order as the historical `StreamSchemaProvider::new` (placeholders, then DataFusion defaults). 
pub fn new() -> Self { - Self::builder() + let mut ctx = Self::builder() .with_streaming_extensions() .expect("streaming extensions") .with_default_functions() .expect("default functions") - .build() + .build(); + ctx.sql_config = crate::sql::planning_runtime::sql_planning_snapshot(); + ctx } pub fn register_stream_table(&mut self, table: StreamTable) { diff --git a/src/sql/types/mod.rs b/src/sql/types/mod.rs index c9d80681..d5124bcc 100644 --- a/src/sql/types/mod.rs +++ b/src/sql/types/mod.rs @@ -38,12 +38,15 @@ pub enum ProcessingMode { #[derive(Clone, Debug)] pub struct SqlConfig { pub default_parallelism: usize, + /// Physical pipeline parallelism for [`KeyExtractionNode`](crate::sql::logical_node::key_calculation::KeyExtractionNode) / KeyBy. + pub key_by_parallelism: usize, } impl Default for SqlConfig { fn default() -> Self { Self { default_parallelism: sql_planning_default::DEFAULT_PARALLELISM, + key_by_parallelism: sql_planning_default::DEFAULT_KEY_BY_PARALLELISM, } } } From 06e69053db56ba3caea3256f06db7a1dc9c19495 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 19 Apr 2026 18:53:22 +0800 Subject: [PATCH 18/26] update --- src/sql/planning_runtime.rs | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 src/sql/planning_runtime.rs diff --git a/src/sql/planning_runtime.rs b/src/sql/planning_runtime.rs new file mode 100644 index 00000000..dc4749ad --- /dev/null +++ b/src/sql/planning_runtime.rs @@ -0,0 +1,35 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +//! Runtime-installed SQL planning defaults (from `GlobalConfig` / `conf/config.yaml`). + +use std::sync::OnceLock; + +use crate::config::streaming_job::ResolvedStreamingJobConfig; +use crate::sql::common::constants::sql_planning_default; +use crate::sql::types::SqlConfig; + +static SQL_PLANNING: OnceLock = OnceLock::new(); + +/// Installs [`SqlConfig`] derived from resolved streaming job YAML (KeyBy parallelism, etc.). +/// Safe to call once at bootstrap; later calls are ignored if already set. +pub fn install_sql_planning_from_streaming_job(job: &ResolvedStreamingJobConfig) { + let cfg = SqlConfig { + default_parallelism: sql_planning_default::DEFAULT_PARALLELISM, + key_by_parallelism: job.key_by_parallelism as usize, + }; + let _ = SQL_PLANNING.set(cfg).ok(); +} + +pub(crate) fn sql_planning_snapshot() -> SqlConfig { + SQL_PLANNING.get().cloned().unwrap_or_default() +} From e2f610ba3cb79ab25853742b22bd16a9ea5fc3f3 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 19 Apr 2026 18:54:08 +0800 Subject: [PATCH 19/26] update --- src/sql/logical_node/logical/operator_chain.rs | 4 +--- src/sql/logical_planner/optimizers/chaining.rs | 9 +++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/sql/logical_node/logical/operator_chain.rs b/src/sql/logical_node/logical/operator_chain.rs index be8f3b53..2aecddd6 100644 --- a/src/sql/logical_node/logical/operator_chain.rs +++ b/src/sql/logical_node/logical/operator_chain.rs @@ -135,9 +135,7 @@ impl OperatorChain { self.operators.iter().all(|op| { matches!( op.operator_name, - OperatorName::Projection - | OperatorName::Value - | OperatorName::ExpressionWatermark + OperatorName::Projection | OperatorName::Value | OperatorName::ExpressionWatermark ) }) } diff --git a/src/sql/logical_planner/optimizers/chaining.rs b/src/sql/logical_planner/optimizers/chaining.rs index 59287d88..ea7bd885 100644 --- 
a/src/sql/logical_planner/optimizers/chaining.rs +++ b/src/sql/logical_planner/optimizers/chaining.rs @@ -98,9 +98,7 @@ impl Optimizer for ChainingOptimizer { source_node.description = format!("{} -> {}", source_node.description, target_node.description); - source_node.parallelism = source_node - .parallelism - .max(target_node.parallelism); + source_node.parallelism = source_node.parallelism.max(target_node.parallelism); source_node .operator_chain @@ -176,7 +174,10 @@ mod tests { let changed = ChainingOptimizer {}.optimize_once(&mut g); assert!(changed); assert_eq!(g.node_count(), 2); - let fused = g.node_weights().find(|n| n.description.contains("->")).unwrap(); + let fused = g + .node_weights() + .find(|n| n.description.contains("->")) + .unwrap(); assert_eq!(fused.parallelism, 8); assert_eq!(fused.operator_chain.len(), 2); } From 0f9f4642636f7db62e5f620eb2ea6689dee18b63 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 19 Apr 2026 23:38:02 +0800 Subject: [PATCH 20/26] update --- protocol/proto/storage.proto | 20 +++ src/runtime/streaming/api/context.rs | 25 +++ src/runtime/streaming/api/operator.rs | 15 +- src/runtime/streaming/api/source.rs | 20 ++- .../streaming/execution/operator_chain.rs | 6 + src/runtime/streaming/execution/pipeline.rs | 2 + .../streaming/execution/source_driver.rs | 17 +- .../streaming/factory/connector/kafka.rs | 8 +- src/runtime/streaming/job/job_manager.rs | 149 ++++++++++++----- .../grouping/incremental_aggregate.rs | 11 +- .../operators/joins/join_instance.rs | 11 +- .../operators/joins/join_with_expiration.rs | 11 +- .../streaming/operators/sink/kafka/mod.rs | 45 ++++++ .../streaming/operators/source/kafka/mod.rs | 140 ++++++++++++++-- .../windows/session_aggregating_window.rs | 11 +- .../windows/sliding_aggregating_window.rs | 11 +- .../windows/tumbling_aggregating_window.rs | 11 +- .../operators/windows/window_function.rs | 11 +- src/runtime/streaming/protocol/control.rs | 24 ++- src/runtime/streaming/state/operator_state.rs | 
22 ++- src/storage/stream_catalog/manager.rs | 151 ++++++++++++++++-- src/storage/stream_catalog/mod.rs | 4 +- 22 files changed, 642 insertions(+), 83 deletions(-) diff --git a/protocol/proto/storage.proto b/protocol/proto/storage.proto index 828bbac5..66f3b0f2 100644 --- a/protocol/proto/storage.proto +++ b/protocol/proto/storage.proto @@ -43,6 +43,21 @@ message CatalogSourceTable { // Streaming table storage (CREATE STREAMING TABLE persistence) // ============================================================================= +// Partition offset for one Kafka partition at a completed checkpoint. +message KafkaPartitionOffset { + int32 partition = 1; + int64 offset = 2; +} + +// Kafka source subtask checkpoint: one file / one TaskContext (pipeline + subtask). +message KafkaSourceSubtaskCheckpoint { + uint32 pipeline_id = 1; + uint32 subtask_index = 2; + // Epoch of the barrier when this snapshot was taken (aligns with latest_checkpoint_epoch on commit). + uint64 checkpoint_epoch = 3; + repeated KafkaPartitionOffset partitions = 4; +} + // Persisted record for one streaming table (CREATE STREAMING TABLE). // On restart, the engine re-submits each record to JobManager to resume the pipeline. message StreamingTableDefinition { @@ -58,6 +73,11 @@ message StreamingTableDefinition { // Last globally-committed checkpoint epoch. // Updated by JobManager after all operators ACK. Used for crash recovery. uint64 latest_checkpoint_epoch = 6; + + // Kafka source per-subtask offsets at the same committed epoch as `latest_checkpoint_epoch`. + // Populated by the runtime coordinator from source checkpoint ACKs. Optional `.bin` files under + // the job state dir may exist only for local recovery materialization from this field. 
+ repeated KafkaSourceSubtaskCheckpoint kafka_source_checkpoints = 7; } // ============================================================================= diff --git a/src/runtime/streaming/api/context.rs b/src/runtime/streaming/api/context.rs index b5b723f5..7ce8cf6c 100644 --- a/src/runtime/streaming/api/context.rs +++ b/src/runtime/streaming/api/context.rs @@ -16,9 +16,12 @@ use std::time::{Duration, SystemTime}; use anyhow::{Context, Result, anyhow}; use arrow_array::RecordBatch; +use protocol::storage::KafkaSourceSubtaskCheckpoint; +use tokio::sync::mpsc; use crate::runtime::memory::{MemoryBlock, MemoryPool, get_array_memory_size}; use crate::runtime::streaming::network::endpoint::PhysicalSender; +use crate::runtime::streaming::protocol::control::JobMasterEvent; use crate::runtime::streaming::protocol::event::{StreamEvent, TrackedEvent}; use crate::runtime::streaming::state::IoManager; @@ -74,6 +77,9 @@ pub struct TaskContext { /// Last globally-committed safe epoch for crash recovery. safe_epoch: u64, + + /// When set, pipelines report checkpoint completion (and optional Kafka offsets) to the job coordinator. + checkpoint_ack_tx: Option>, } impl TaskContext { @@ -90,6 +96,7 @@ impl TaskContext { pipeline_state_memory_block: Option>, operator_state_memory_bytes: u64, safe_epoch: u64, + checkpoint_ack_tx: Option>, ) -> Self { let task_name = format!( "Task-[{}]-Pipe[{}]-Sub[{}/{}]", @@ -111,6 +118,7 @@ impl TaskContext { pipeline_state_memory_block, operator_state_memory_bytes, safe_epoch, + checkpoint_ack_tx, } } @@ -119,6 +127,23 @@ impl TaskContext { self.safe_epoch } + /// Notify the job checkpoint coordinator that this pipeline has finished the barrier for `epoch`. 
+ pub async fn send_checkpoint_ack( + &self, + epoch: u64, + kafka_subtask: Option, + ) { + if let Some(tx) = &self.checkpoint_ack_tx { + let _ = tx + .send(JobMasterEvent::CheckpointAck { + pipeline_id: self.pipeline_id, + epoch, + kafka_subtask, + }) + .await; + } + } + #[inline] pub fn config(&self) -> &TaskContextConfig { &self.config diff --git a/src/runtime/streaming/api/operator.rs b/src/runtime/streaming/api/operator.rs index df8f0dcb..53fa629f 100644 --- a/src/runtime/streaming/api/operator.rs +++ b/src/runtime/streaming/api/operator.rs @@ -54,11 +54,24 @@ pub trait Operator: Send + 'static { ctx: &mut TaskContext, ) -> anyhow::Result<()>; + /// Global checkpoint **phase 2** (after metadata is durable): finalize external side effects. + /// + /// Default is no-op. Examples of overrides: transactional Kafka sink calls + /// `commit_transaction` on the producer stashed during [`Self::snapshot_state`]. async fn commit_checkpoint( &mut self, - _epoch: u32, + epoch: u32, _ctx: &mut TaskContext, ) -> anyhow::Result<()> { + let _ = epoch; + Ok(()) + } + + /// Global checkpoint **rollback** when phase 2 must not commit (e.g. catalog persist failed). + /// + /// Default is no-op. Transactional Kafka sink overrides with `abort_transaction` on the stashed producer. 
+ async fn abort_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> anyhow::Result<()> { + let _ = epoch; Ok(()) } diff --git a/src/runtime/streaming/api/source.rs b/src/runtime/streaming/api/source.rs index 81435b47..c6597143 100644 --- a/src/runtime/streaming/api/source.rs +++ b/src/runtime/streaming/api/source.rs @@ -14,6 +14,7 @@ use crate::runtime::streaming::api::context::TaskContext; use crate::sql::common::{CheckpointBarrier, Watermark}; use arrow_array::RecordBatch; use async_trait::async_trait; +use protocol::storage::KafkaSourceSubtaskCheckpoint; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum SourceOffset { @@ -31,6 +32,12 @@ pub enum SourceEvent { EndOfStream, } +/// Optional metadata returned when a source completes a checkpoint barrier snapshot. +#[derive(Debug, Default, Clone)] +pub struct SourceCheckpointReport { + pub kafka_subtask: Option, +} + #[async_trait] pub trait SourceOperator: Send + 'static { fn name(&self) -> &str; @@ -49,13 +56,22 @@ pub trait SourceOperator: Send + 'static { &mut self, barrier: CheckpointBarrier, ctx: &mut TaskContext, - ) -> anyhow::Result<()>; + ) -> anyhow::Result; + /// Same checkpoint **phase 2** hook as [`super::operator::Operator::commit_checkpoint`]. + /// Kafka source keeps the default: offsets are reported at the barrier in [`Self::snapshot_state`]. async fn commit_checkpoint( &mut self, - _epoch: u32, + epoch: u32, _ctx: &mut TaskContext, ) -> anyhow::Result<()> { + let _ = epoch; + Ok(()) + } + + /// Same rollback hook as [`super::operator::Operator::abort_checkpoint`]. 
+ async fn abort_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> anyhow::Result<()> { + let _ = epoch; Ok(()) } diff --git a/src/runtime/streaming/execution/operator_chain.rs b/src/runtime/streaming/execution/operator_chain.rs index a2e6c5c6..0b76568b 100644 --- a/src/runtime/streaming/execution/operator_chain.rs +++ b/src/runtime/streaming/execution/operator_chain.rs @@ -158,6 +158,9 @@ impl OperatorDrive for IntermediateDriver { ControlCommand::Commit { epoch } => { self.operator.commit_checkpoint(*epoch, ctx).await?; } + ControlCommand::AbortCheckpoint { epoch } => { + self.operator.abort_checkpoint(*epoch, ctx).await?; + } ControlCommand::Stop { mode } if *mode == StopMode::Immediate => { stop = true; } @@ -273,6 +276,9 @@ impl OperatorDrive for TailDriver { ControlCommand::Commit { epoch } => { self.operator.commit_checkpoint(*epoch, ctx).await?; } + ControlCommand::AbortCheckpoint { epoch } => { + self.operator.abort_checkpoint(*epoch, ctx).await?; + } ControlCommand::Stop { mode } if *mode == StopMode::Immediate => { stop = true; } diff --git a/src/runtime/streaming/execution/pipeline.rs b/src/runtime/streaming/execution/pipeline.rs index d6ef06a3..114399ae 100644 --- a/src/runtime/streaming/execution/pipeline.rs +++ b/src/runtime/streaming/execution/pipeline.rs @@ -110,6 +110,7 @@ impl Pipeline { } } AlignmentStatus::Complete => { + let epoch = barrier.epoch as u64; self.chain_head .process_event( idx, @@ -123,6 +124,7 @@ impl Pipeline { active_streams.insert(i, stream); } } + self.ctx.send_checkpoint_ack(epoch, None).await; } } } diff --git a/src/runtime/streaming/execution/source_driver.rs b/src/runtime/streaming/execution/source_driver.rs index 6813a82a..28bb41ca 100644 --- a/src/runtime/streaming/execution/source_driver.rs +++ b/src/runtime/streaming/execution/source_driver.rs @@ -15,7 +15,7 @@ use tokio::time::{Instant, sleep}; use tracing::{Instrument, info, info_span, warn}; use crate::runtime::streaming::api::context::TaskContext; -use 
crate::runtime::streaming::api::source::{SourceEvent, SourceOperator}; +use crate::runtime::streaming::api::source::{SourceCheckpointReport, SourceEvent, SourceOperator}; use crate::runtime::streaming::error::RunError; use crate::runtime::streaming::execution::OperatorDrive; use crate::runtime::streaming::protocol::{ @@ -154,18 +154,25 @@ impl SourceDriver { async fn handle_control(&mut self, cmd: ControlCommand) -> Result { let mut stop = false; + let mut pending_source_checkpoint: Option<(u64, SourceCheckpointReport)> = None; match &cmd { ControlCommand::TriggerCheckpoint { barrier } => { let b: CheckpointBarrier = barrier.clone().into(); - self.operator.snapshot_state(b, &mut self.ctx).await?; + let report = self.operator.snapshot_state(b, &mut self.ctx).await?; self.dispatch_event(StreamEvent::Barrier(b)).await?; + pending_source_checkpoint = Some((b.epoch as u64, report)); } ControlCommand::Commit { epoch } => { self.operator .commit_checkpoint(*epoch, &mut self.ctx) .await?; } + ControlCommand::AbortCheckpoint { epoch } => { + self.operator + .abort_checkpoint(*epoch, &mut self.ctx) + .await?; + } ControlCommand::Stop { .. 
} => { stop = true; } @@ -178,6 +185,12 @@ impl SourceDriver { stop = true; } + if let Some((epoch, report)) = pending_source_checkpoint { + self.ctx + .send_checkpoint_ack(epoch, report.kafka_subtask) + .await; + } + Ok(stop) } diff --git a/src/runtime/streaming/factory/connector/kafka.rs b/src/runtime/streaming/factory/connector/kafka.rs index 75135197..9d2f114d 100644 --- a/src/runtime/streaming/factory/connector/kafka.rs +++ b/src/runtime/streaming/factory/connector/kafka.rs @@ -200,7 +200,13 @@ impl KafkaConnectorDispatcher { let client_configs = merge_client_configs(&cfg.auth, &cfg.client_configs); let consistency = match cfg.commit_mode() { - KafkaSinkCommitMode::KafkaSinkExactlyOnce => ConsistencyMode::ExactlyOnce, + KafkaSinkCommitMode::KafkaSinkExactlyOnce => { + info!( + topic = %cfg.topic, + "Kafka sink exactly-once: transactional producer + checkpoint 2PC. Downstream Kafka consumers of this topic should set isolation.level=read_committed." + ); + ConsistencyMode::ExactlyOnce + } KafkaSinkCommitMode::KafkaSinkAtLeastOnce => ConsistencyMode::AtLeastOnce, }; diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index 549cb314..ea4049aa 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -23,6 +23,7 @@ use tokio_stream::wrappers::ReceiverStream; use tracing::{debug, error, info, warn}; use protocol::function_stream_graph::{ChainedOperator, FsProgram}; +use protocol::storage::KafkaSourceSubtaskCheckpoint; use crate::config::{ DEFAULT_CHECKPOINT_INTERVAL_MS, DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, @@ -137,6 +138,17 @@ enum PipelineRunner { Standard(Pipeline), } +struct CheckpointCoordinatorConfig { + job_id: String, + source_control_txs: Vec>, + all_pipeline_control_txs: Vec>, + job_master_rx: mpsc::Receiver, + expected_pipeline_ids: HashSet, + interval_ms: u64, + start_epoch: u64, + job_state_dir: PathBuf, +} + impl PipelineRunner { async fn run(self) -> 
Result<(), crate::runtime::streaming::error::RunError> { match self { @@ -198,6 +210,12 @@ impl JobManager { self.state_config.pipeline_parallelism } + /// Per-job state directory (Kafka offset snapshots, operator state roots, etc.). + #[inline] + pub fn job_state_directory(&self, job_id: &str) -> PathBuf { + self.state_base_dir.join(job_id) + } + pub async fn submit_job( &self, job_id: String, @@ -209,6 +227,7 @@ impl JobManager { let mut pipelines = HashMap::with_capacity(program.nodes.len()); let mut source_control_txs = Vec::new(); + let mut all_pipeline_control_txs = Vec::new(); let mut expected_pipeline_ids = HashSet::new(); let job_state_dir = self.state_base_dir.join(&job_id); @@ -242,6 +261,7 @@ impl JobManager { if is_source { source_control_txs.push(pipeline.control_tx.clone()); } + all_pipeline_control_txs.push(pipeline.control_tx.clone()); expected_pipeline_ids.insert(pipeline_id); pipelines.insert(pipeline_id, pipeline); } @@ -249,14 +269,16 @@ impl JobManager { let interval_ms = custom_checkpoint_interval_ms.unwrap_or(self.state_config.checkpoint_interval_ms); - self.spawn_checkpoint_coordinator( - job_id.clone(), + self.spawn_checkpoint_coordinator(CheckpointCoordinatorConfig { + job_id: job_id.clone(), source_control_txs, + all_pipeline_control_txs, job_master_rx, expected_pipeline_ids, interval_ms, - safe_epoch + 1, - ); + start_epoch: safe_epoch + 1, + job_state_dir: job_state_dir.clone(), + }); let graph = PhysicalExecutionGraph { job_id: job_id.clone(), @@ -457,7 +479,7 @@ impl JobManager { declared_parallelism: u32, edge_manager: &mut EdgeManager, job_state_dir: &Path, - _job_master_tx: mpsc::Sender, + job_master_tx: mpsc::Sender, recovery_epoch: u64, ) -> Result<(PhysicalPipeline, bool)> { let (raw_inboxes, raw_outboxes) = @@ -536,6 +558,7 @@ impl JobManager { pipeline_state_memory_block, per_op, recovery_epoch, + Some(job_master_tx.clone()), ); let runner = if let Some(source) = chain.source { @@ -676,15 +699,17 @@ impl JobManager { // 
Chandy-Lamport distributed snapshot barrier coordinator // ======================================================================== - fn spawn_checkpoint_coordinator( - &self, - job_id: String, - source_control_txs: Vec>, - mut job_master_rx: mpsc::Receiver, - expected_pipeline_ids: HashSet, - interval_ms: u64, - start_epoch: u64, - ) -> TokioJoinHandle<()> { + fn spawn_checkpoint_coordinator(&self, cfg: CheckpointCoordinatorConfig) -> TokioJoinHandle<()> { + let CheckpointCoordinatorConfig { + job_id, + source_control_txs, + all_pipeline_control_txs, + mut job_master_rx, + expected_pipeline_ids, + interval_ms, + start_epoch, + job_state_dir, + } = cfg; tokio::spawn(async move { if interval_ms == 0 { info!(job_id = %job_id, "Checkpoint disabled for this job"); @@ -696,57 +721,72 @@ impl JobManager { let mut current_epoch: u64 = start_epoch; let mut pending_checkpoints: HashMap> = HashMap::new(); + let mut kafka_reports: HashMap> = HashMap::new(); + + async fn broadcast_checkpoint_phase2( + txs: &[mpsc::Sender], + cmd: ControlCommand, + ) { + for tx in txs { + let _ = tx.send(cmd.clone()).await; + } + } loop { tokio::select! { - _ = interval.tick() => { - info!(job_id = %job_id, epoch = current_epoch, "Triggering global Checkpoint Barrier."); - pending_checkpoints.insert(current_epoch, expected_pipeline_ids.clone()); - - let barrier = CheckpointBarrier { - epoch: current_epoch as u32, - min_epoch: 0, - timestamp: std::time::SystemTime::now(), - then_stop: false, - }; - - for tx in &source_control_txs { - let cmd = ControlCommand::trigger_checkpoint(barrier); - if tx.send(cmd).await.is_err() { - debug!(job_id = %job_id, "Source disconnected. 
Shutting down coordinator."); - return; - } - } - current_epoch += 1; - } + biased; Some(event) = job_master_rx.recv() => { match event { - JobMasterEvent::CheckpointAck { pipeline_id, epoch } => { + JobMasterEvent::CheckpointAck { + pipeline_id, + epoch, + kafka_subtask, + } => { + if let Some(k) = kafka_subtask { + kafka_reports.entry(epoch).or_default().push(k); + } if let Some(pending_set) = pending_checkpoints.get_mut(&epoch) { pending_set.remove(&pipeline_id); if pending_set.is_empty() { info!( job_id = %job_id, epoch = epoch, - "Checkpoint Epoch is GLOBALLY COMPLETED!" + "Checkpoint Epoch is GLOBALLY COMPLETED (phase 1); persisting metadata and notifying operators (phase 2)" ); + let kf = kafka_reports.remove(&epoch).unwrap_or_default(); + let epoch_u32 = u32::try_from(epoch).unwrap_or(u32::MAX); + + let mut catalog_ok = true; if let Some(catalog) = CatalogManager::try_global() { - if let Err(e) = catalog.commit_job_checkpoint(&job_id, epoch) { + if let Err(e) = catalog.commit_job_checkpoint( + &job_id, + epoch, + &job_state_dir, + kf, + ) { + catalog_ok = false; error!( job_id = %job_id, epoch = epoch, error = %e, - "Failed to commit checkpoint metadata to Catalog" + "Failed to commit checkpoint metadata to Catalog — aborting transactional sinks" ); } } else { warn!( job_id = %job_id, epoch = epoch, - "CatalogManager not available, checkpoint not persisted globally" + "CatalogManager not available; proceeding with operator Commit (Kafka transactional commit) only" ); } + let phase2 = if catalog_ok { + ControlCommand::Commit { epoch: epoch_u32 } + } else { + ControlCommand::AbortCheckpoint { epoch: epoch_u32 } + }; + broadcast_checkpoint_phase2(&all_pipeline_control_txs, phase2).await; + pending_checkpoints.remove(&epoch); } } @@ -756,10 +796,39 @@ impl JobManager { job_id = %job_id, epoch = epoch, pipeline_id = pipeline_id, reason = %reason, "Checkpoint FAILED!" 
); - pending_checkpoints.remove(&epoch); + if pending_checkpoints.remove(&epoch).is_some() { + kafka_reports.remove(&epoch); + let epoch_u32 = u32::try_from(epoch).unwrap_or(u32::MAX); + broadcast_checkpoint_phase2( + &all_pipeline_control_txs, + ControlCommand::AbortCheckpoint { epoch: epoch_u32 }, + ) + .await; + } } } } + + _ = interval.tick(), if pending_checkpoints.is_empty() => { + info!(job_id = %job_id, epoch = current_epoch, "Triggering global Checkpoint Barrier."); + pending_checkpoints.insert(current_epoch, expected_pipeline_ids.clone()); + + let barrier = CheckpointBarrier { + epoch: current_epoch as u32, + min_epoch: 0, + timestamp: std::time::SystemTime::now(), + then_stop: false, + }; + + for tx in &source_control_txs { + let cmd = ControlCommand::trigger_checkpoint(barrier); + if tx.send(cmd).await.is_err() { + debug!(job_id = %job_id, "Source disconnected. Shutting down coordinator."); + return; + } + } + current_epoch += 1; + } } } }) diff --git a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs index ffa7e2f1..a8983e99 100644 --- a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs +++ b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs @@ -936,7 +936,7 @@ impl Operator for IncrementalAggregatingFunc { // Flush to Parquet store - .snapshot_epoch(barrier.epoch as u64) + .prepare_checkpoint_epoch(barrier.epoch as u64) .map_err(|e| anyhow!("Snapshot failed: {e}"))?; info!( @@ -949,6 +949,15 @@ impl Operator for IncrementalAggregatingFunc { Ok(()) } + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("state store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; + Ok(()) + } + async fn on_close(&mut self, _ctx: &mut TaskContext) -> Result> { Ok(vec![]) } diff --git 
a/src/runtime/streaming/operators/joins/join_instance.rs b/src/runtime/streaming/operators/joins/join_instance.rs index a6f4a53f..cddeeff2 100644 --- a/src/runtime/streaming/operators/joins/join_instance.rs +++ b/src/runtime/streaming/operators/joins/join_instance.rs @@ -358,10 +358,19 @@ impl Operator for InstantJoinOperator { self.state_store .as_ref() .unwrap() - .snapshot_epoch(barrier.epoch as u64) + .prepare_checkpoint_epoch(barrier.epoch as u64) .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } + + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .unwrap() + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; + Ok(()) + } } // ============================================================================ diff --git a/src/runtime/streaming/operators/joins/join_with_expiration.rs b/src/runtime/streaming/operators/joins/join_with_expiration.rs index e044f242..92089428 100644 --- a/src/runtime/streaming/operators/joins/join_with_expiration.rs +++ b/src/runtime/streaming/operators/joins/join_with_expiration.rs @@ -312,13 +312,22 @@ impl Operator for JoinWithExpirationOperator { .expect("State store not initialized"); store - .snapshot_epoch(barrier.epoch as u64) + .prepare_checkpoint_epoch(barrier.epoch as u64) .map_err(|e| anyhow!("Snapshot failed: {e}"))?; info!(epoch = barrier.epoch, "Join Operator snapshotted state."); Ok(()) } + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; + Ok(()) + } + async fn on_close(&mut self, _ctx: &mut TaskContext) -> Result> { Ok(vec![]) } diff --git a/src/runtime/streaming/operators/sink/kafka/mod.rs b/src/runtime/streaming/operators/sink/kafka/mod.rs index a24a098d..bee4d367 100644 --- 
a/src/runtime/streaming/operators/sink/kafka/mod.rs +++ b/src/runtime/streaming/operators/sink/kafka/mod.rs @@ -10,6 +10,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! ## Exactly-once Kafka sink and checkpoint 2PC +//! +//! - **Pre-commit (barrier / `snapshot_state`)**: flush in-flight sends, rotate to a new transactional +//! producer for post-barrier records, and stash the producer that covered this checkpoint interval. +//! - **Commit (`commit_checkpoint`)**: after the job coordinator persists checkpoint metadata (catalog), +//! it broadcasts `ControlCommand::Commit`; this operator calls `commit_transaction` on the stashed +//! producer so consumers with `isolation.level=read_committed` observe the batch. +//! - **Abort (`abort_checkpoint`)**: if metadata commit fails or the checkpoint is declined, the +//! coordinator broadcasts `AbortCheckpoint` and this operator calls `abort_transaction` on the +//! stashed producer. 
+ use anyhow::{Result, anyhow, bail}; use arrow_array::Array; use arrow_array::RecordBatch; @@ -115,6 +126,12 @@ impl KafkaSinkOperator { if let Some(idx) = tx_index { config.set("enable.idempotence", "true"); + if config.get("acks").is_none() { + config.set("acks", "all"); + } + if config.get("transaction.timeout.ms").is_none() { + config.set("transaction.timeout.ms", "600000"); + } let transactional_id = format!( "fs-tx-{}-{}-{}-{}", ctx.job_id, self.topic, ctx.subtask_index, idx @@ -361,6 +378,34 @@ impl Operator for KafkaSinkOperator { Ok(()) } + async fn abort_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + if matches!(self.consistency_mode, ConsistencyMode::AtLeastOnce) { + return Ok(()); + } + + let state = self.transactional_state.as_mut().unwrap(); + let Some(stale) = state.producer_awaiting_commit.take() else { + warn!( + "AbortCheckpoint epoch {} but no stashed transactional producer (already committed or duplicate signal)", + epoch + ); + return Ok(()); + }; + + match stale.abort_transaction(Timeout::After(Duration::from_secs(30))) { + Ok(()) => info!( + "Aborted Kafka transaction for epoch {} (checkpoint metadata did not commit)", + epoch + ), + Err(e) => warn!( + "Kafka abort_transaction for epoch {} returned error (producer dropped): {}", + epoch, e + ), + } + + Ok(()) + } + async fn on_close(&mut self, _ctx: &mut TaskContext) -> Result> { self.flush_to_broker().await?; info!("Kafka sink shut down gracefully."); diff --git a/src/runtime/streaming/operators/source/kafka/mod.rs b/src/runtime/streaming/operators/source/kafka/mod.rs index e73d18fa..3bf65eca 100644 --- a/src/runtime/streaming/operators/source/kafka/mod.rs +++ b/src/runtime/streaming/operators/source/kafka/mod.rs @@ -10,21 +10,28 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Kafka source checkpointing: `enable.auto.commit=false`, offsets captured at the checkpoint barrier +//! 
and reported to the job coordinator for catalog persistence; restart rewinds from that snapshot. + use anyhow::{Context as _, Result, anyhow}; use arrow_array::RecordBatch; use arrow_schema::SchemaRef; use async_trait::async_trait; use bincode::{Decode, Encode}; use governor::{DefaultDirectRateLimiter, Quota, RateLimiter as GovernorRateLimiter}; +use protocol::storage::{KafkaPartitionOffset, KafkaSourceSubtaskCheckpoint}; use rdkafka::consumer::{CommitMode, Consumer, StreamConsumer}; use rdkafka::{ClientConfig, Message as KMessage, Offset, TopicPartitionList}; use std::collections::HashMap; use std::num::NonZeroU32; +use std::path::PathBuf; use std::time::{Duration, Instant}; use tracing::{debug, error, info, warn}; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::source::{SourceEvent, SourceOffset, SourceOperator}; +use crate::runtime::streaming::api::source::{ + SourceCheckpointReport, SourceEvent, SourceOffset, SourceOperator, +}; use crate::runtime::streaming::format::{BadDataPolicy, DataDeserializer, Format}; use crate::sql::common::fs_schema::FieldValueType; use crate::sql::common::{CheckpointBarrier, MetadataField}; @@ -33,8 +40,74 @@ use crate::sql::common::{CheckpointBarrier, MetadataField}; #[derive(Copy, Clone, Debug, Encode, Decode, PartialEq, PartialOrd)] pub struct KafkaState { - partition: i32, - offset: i64, + pub partition: i32, + pub offset: i64, +} + +/// Last committed partition offsets for this source subtask, tied to a checkpoint epoch. +/// Materialized into a `.bin` under the job state dir from catalog before restart; see +/// [`TaskContext::latest_safe_epoch`] and `StreamingTableDefinition` in `storage.proto`. +#[derive(Debug, Encode, Decode)] +pub(crate) struct KafkaSourceSavedOffsets { + /// Same numbering as [`CheckpointBarrier::epoch`] / catalog `latest_checkpoint_epoch` (as u64). 
+ pub(crate) epoch: u64, + pub(crate) partitions: Vec, +} + +pub(crate) fn encode_kafka_offset_snapshot(saved: &KafkaSourceSavedOffsets) -> Result> { + bincode::encode_to_vec(saved, bincode::config::standard()) + .map_err(|e| anyhow!("bincode encode Kafka offset snapshot: {e}")) +} + +pub(crate) fn decode_kafka_offset_snapshot(bytes: &[u8]) -> Result { + let (saved, _) = bincode::decode_from_slice(bytes, bincode::config::standard()) + .map_err(|e| anyhow!("bincode decode Kafka offset snapshot: {e}"))?; + Ok(saved) +} + +pub(crate) fn kafka_snapshot_path( + job_dir: &std::path::Path, + pipeline_id: u32, + subtask_index: u32, +) -> PathBuf { + job_dir.join(format!( + "kafka_source_offsets_pipe{}_sub{}.bin", + pipeline_id, subtask_index + )) +} + +fn kafka_offsets_snapshot_path(ctx: &TaskContext) -> PathBuf { + kafka_snapshot_path(&ctx.state_dir, ctx.pipeline_id, ctx.subtask_index) +} + +fn load_saved_offsets_if_recovering(ctx: &TaskContext) -> Option { + let safe = ctx.latest_safe_epoch(); + if safe == 0 { + return None; + } + let path = kafka_offsets_snapshot_path(ctx); + let bytes = std::fs::read(&path).ok()?; + let saved = match decode_kafka_offset_snapshot(&bytes) { + Ok(v) => v, + Err(e) => { + warn!( + path = %path.display(), + error = %e, + "Failed to decode Kafka offset snapshot" + ); + return None; + } + }; + if saved.epoch > safe { + warn!( + path = %path.display(), + saved_epoch = saved.epoch, + safe_epoch = safe, + "Ignoring Kafka offset snapshot newer than catalog safe epoch" + ); + return None; + } + Some(saved) } pub trait BatchDeserializer: Send + 'static { @@ -182,7 +255,11 @@ impl KafkaSourceOperator { } } - async fn init_and_assign_consumer(&mut self, ctx: &mut TaskContext) -> Result<()> { + async fn init_and_assign_consumer( + &mut self, + ctx: &mut TaskContext, + saved_offsets: Option, + ) -> Result<()> { info!("Creating kafka consumer for {}", self.bootstrap_servers); let mut client_config = ClientConfig::new(); @@ -205,8 +282,24 @@ impl 
KafkaSourceOperator { .set("group.id", &group_id) .create()?; - let has_state = false; - let state_map: HashMap = HashMap::new(); + let (has_state, state_map) = if let Some(saved) = saved_offsets { + info!( + job_id = %ctx.job_id, + pipeline_id = ctx.pipeline_id, + subtask = ctx.subtask_index, + epoch = saved.epoch, + safe_epoch = ctx.latest_safe_epoch(), + partitions = saved.partitions.len(), + "Restoring Kafka source offsets from materialized checkpoint snapshot" + ); + let mut m = HashMap::with_capacity(saved.partitions.len()); + for s in saved.partitions { + m.insert(s.partition, s); + } + (true, m) + } else { + (false, HashMap::new()) + }; let metadata = consumer .fetch_metadata(Some(&self.topic), Duration::from_secs(30)) @@ -224,9 +317,10 @@ impl KafkaSourceOperator { for p in partitions { if p.id().rem_euclid(pmax) == ctx.subtask_index as i32 { + // `current_offsets` / snapshot store last consumed offset; resume at next offset. let offset = state_map .get(&p.id()) - .map(|s| Offset::Offset(s.offset)) + .map(|s| Offset::Offset(s.offset.saturating_add(1))) .unwrap_or_else(|| { if has_state { Offset::Beginning @@ -264,7 +358,8 @@ impl SourceOperator for KafkaSourceOperator { } async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { - self.init_and_assign_consumer(ctx).await?; + let saved = load_saved_offsets_if_recovering(ctx); + self.init_and_assign_consumer(ctx, saved).await?; self.rate_limiter = Some(GovernorRateLimiter::direct(Quota::per_second( self.messages_per_second, ))); @@ -363,10 +458,13 @@ impl SourceOperator for KafkaSourceOperator { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, ctx: &mut TaskContext, - ) -> Result<()> { - debug!("Source [{}] executing checkpoint", ctx.subtask_index); + ) -> Result { + debug!( + "Source [{}] executing checkpoint epoch {}", + ctx.subtask_index, barrier.epoch + ); let mut topic_partitions = TopicPartitionList::new(); for (&partition, &offset) in 
&self.current_offsets { @@ -381,7 +479,25 @@ impl SourceOperator for KafkaSourceOperator { warn!("Failed to commit async offset to Kafka Broker: {:?}", e); } - Ok(()) + let epoch = u64::from(barrier.epoch); + let kafka_subtask = if self.current_offsets.is_empty() { + None + } else { + let mut parts: Vec<(i32, i64)> = + self.current_offsets.iter().map(|(&p, &o)| (p, o)).collect(); + parts.sort_by_key(|x| x.0); + Some(KafkaSourceSubtaskCheckpoint { + pipeline_id: ctx.pipeline_id, + subtask_index: ctx.subtask_index, + checkpoint_epoch: epoch, + partitions: parts + .into_iter() + .map(|(partition, offset)| KafkaPartitionOffset { partition, offset }) + .collect(), + }) + }; + + Ok(SourceCheckpointReport { kafka_subtask }) } async fn on_close(&mut self, _ctx: &mut TaskContext) -> Result<()> { diff --git a/src/runtime/streaming/operators/windows/session_aggregating_window.rs b/src/runtime/streaming/operators/windows/session_aggregating_window.rs index 37be0b04..ad32f73f 100644 --- a/src/runtime/streaming/operators/windows/session_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/session_aggregating_window.rs @@ -866,7 +866,7 @@ impl Operator for SessionWindowOperator { self.state_store .as_ref() .expect("State store not initialized") - .snapshot_epoch(barrier.epoch as u64) + .prepare_checkpoint_epoch(barrier.epoch as u64) .map_err(|e| anyhow!("Snapshot failed: {e}"))?; info!( @@ -876,6 +876,15 @@ impl Operator for SessionWindowOperator { Ok(()) } + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; + Ok(()) + } + async fn on_close(&mut self, _ctx: &mut TaskContext) -> Result> { Ok(vec![]) } diff --git a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs 
index 02666c03..64d09b8d 100644 --- a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs @@ -611,11 +611,20 @@ impl Operator for SlidingWindowOperator { self.state_store .as_ref() .expect("State store not initialized") - .snapshot_epoch(barrier.epoch as u64) + .prepare_checkpoint_epoch(barrier.epoch as u64) .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; + Ok(()) + } + async fn on_close(&mut self, _ctx: &mut TaskContext) -> Result> { Ok(vec![]) } diff --git a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs index f4a17fd1..4e48c50c 100644 --- a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs @@ -389,11 +389,20 @@ impl Operator for TumblingWindowOperator { self.state_store .as_ref() .expect("State store not initialized") - .snapshot_epoch(barrier.epoch as u64) + .prepare_checkpoint_epoch(barrier.epoch as u64) .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; + Ok(()) + } + async fn on_close(&mut self, _ctx: &mut TaskContext) -> Result> { Ok(vec![]) } diff --git a/src/runtime/streaming/operators/windows/window_function.rs b/src/runtime/streaming/operators/windows/window_function.rs index a379cd2d..585e51bb 100644 --- 
a/src/runtime/streaming/operators/windows/window_function.rs +++ b/src/runtime/streaming/operators/windows/window_function.rs @@ -269,11 +269,20 @@ impl Operator for WindowFunctionOperator { self.state_store .as_ref() .expect("State store not initialized") - .snapshot_epoch(barrier.epoch as u64) + .prepare_checkpoint_epoch(barrier.epoch as u64) .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; + Ok(()) + } + async fn on_close(&mut self, _ctx: &mut TaskContext) -> Result> { Ok(vec![]) } diff --git a/src/runtime/streaming/protocol/control.rs b/src/runtime/streaming/protocol/control.rs index e87ccd3b..08719e97 100644 --- a/src/runtime/streaming/protocol/control.rs +++ b/src/runtime/streaming/protocol/control.rs @@ -11,6 +11,7 @@ // limitations under the License. use super::event::CheckpointBarrier; +use protocol::storage::KafkaSourceSubtaskCheckpoint; use serde::{Deserialize, Serialize}; use std::time::Duration; use tokio::sync::mpsc::{self, Receiver, Sender}; @@ -55,11 +56,24 @@ impl From for CheckpointBarrier { #[derive(Debug, Clone, Serialize, Deserialize)] pub enum ControlCommand { Start, - Stop { mode: StopMode }, + Stop { + mode: StopMode, + }, DropState, - Commit { epoch: u32 }, - UpdateConfig { config_json: String }, - TriggerCheckpoint { barrier: CheckpointBarrierWire }, + /// Phase 2 of checkpoint 2PC: metadata durable; transactional Kafka sink should `commit_transaction`. + Commit { + epoch: u32, + }, + /// Roll back pre-committed transactional Kafka writes when checkpoint metadata commit failed or barrier declined. 
+ AbortCheckpoint { + epoch: u32, + }, + UpdateConfig { + config_json: String, + }, + TriggerCheckpoint { + barrier: CheckpointBarrierWire, + }, } impl ControlCommand { @@ -85,6 +99,8 @@ pub enum JobMasterEvent { CheckpointAck { pipeline_id: u32, epoch: u64, + /// Kafka source subtask progress at this barrier (only source pipelines set this). + kafka_subtask: Option, }, CheckpointDecline { pipeline_id: u32, diff --git a/src/runtime/streaming/state/operator_state.rs b/src/runtime/streaming/state/operator_state.rs index 39de499d..a3514461 100644 --- a/src/runtime/streaming/state/operator_state.rs +++ b/src/runtime/streaming/state/operator_state.rs @@ -220,14 +220,34 @@ impl OperatorStateStore { Ok(()) } - pub fn snapshot_epoch(self: &Arc, epoch: u64) -> Result<()> { + /// Checkpoint phase 1: flush mutable in-memory state into an epoch-tagged immutable table and + /// trigger spill. This does NOT advance `current_epoch`. + pub fn prepare_checkpoint_epoch(self: &Arc, epoch: u64) -> Result<()> { self.downgrade_active_table(epoch); self.trigger_spill(); + Ok(()) + } + + /// Checkpoint phase 2: once global metadata commit succeeds, advance the durable safe epoch. + pub fn commit_checkpoint_epoch(self: &Arc, epoch: u64) -> Result<()> { self.current_epoch .store(epoch.saturating_add(1), Ordering::Release); Ok(()) } + /// Checkpoint rollback: do not advance `current_epoch`. Any already-spilled files are kept and + /// filtered by safe epoch during restore. + pub fn abort_checkpoint_epoch(self: &Arc, _epoch: u64) -> Result<()> { + Ok(()) + } + + /// Backward-compatible helper (phase1 + phase2 in one call). 
+ pub fn snapshot_epoch(self: &Arc, epoch: u64) -> Result<()> { + self.prepare_checkpoint_epoch(epoch)?; + self.commit_checkpoint_epoch(epoch)?; + Ok(()) + } + pub async fn await_spill_complete(&self) { while self.is_spilling.load(Ordering::SeqCst) { self.spill_notify.notified().await; diff --git a/src/storage/stream_catalog/manager.rs b/src/storage/stream_catalog/manager.rs index 471e3cd9..3c9d561e 100644 --- a/src/storage/stream_catalog/manager.rs +++ b/src/storage/stream_catalog/manager.rs @@ -10,6 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::path::Path; use std::sync::{Arc, OnceLock}; use anyhow::{Context, anyhow, bail}; @@ -18,6 +19,8 @@ use prost::Message; use protocol::function_stream_graph::FsProgram; use protocol::storage::{self as pb, table_definition}; use tracing::{debug, info, warn}; + +use crate::runtime::streaming::operators::source::kafka as kafka_snap; use unicase::UniCase; use crate::sql::common::constants::sql_field; @@ -33,6 +36,93 @@ use super::meta_store::MetaStore; const CATALOG_KEY_PREFIX: &str = "catalog:stream_table:"; const STREAMING_JOB_KEY_PREFIX: &str = "streaming_job:"; +/// One persisted streaming job row from catalog (program + checkpoint metadata + Kafka offsets). 
+#[derive(Debug, Clone)] +pub struct StoredStreamingJob { + pub table_name: String, + pub program: FsProgram, + pub checkpoint_interval_ms: u64, + pub latest_checkpoint_epoch: u64, + pub kafka_source_checkpoints: Vec, +} + +fn parse_kafka_offset_snapshot_filename(name: &str) -> Option<(u32, u32)> { + const PREFIX: &str = "kafka_source_offsets_pipe"; + const SUFFIX: &str = ".bin"; + if !name.starts_with(PREFIX) || !name.ends_with(SUFFIX) { + return None; + } + let mid = name.strip_prefix(PREFIX)?.strip_suffix(SUFFIX)?; + let (pipe, sub_part) = mid.split_once("_sub")?; + Some((pipe.parse().ok()?, sub_part.parse().ok()?)) +} + +/// Removes on-disk staging snapshots once their payload is committed into catalog (same epoch). +fn cleanup_kafka_offset_snapshots_for_epoch(job_dir: &Path, epoch: u64) { + let Ok(rd) = std::fs::read_dir(job_dir) else { + return; + }; + for ent in rd.flatten() { + let path = ent.path(); + let name = ent.file_name().to_string_lossy().into_owned(); + if parse_kafka_offset_snapshot_filename(&name).is_none() { + continue; + } + let Ok(bytes) = std::fs::read(&path) else { + continue; + }; + let Ok(saved) = kafka_snap::decode_kafka_offset_snapshot(&bytes) else { + continue; + }; + if saved.epoch == epoch && std::fs::remove_file(&path).is_err() { + debug!(path = %path.display(), "Could not remove staged Kafka offset snapshot (non-fatal)"); + } + } +} + +/// Writes catalog-stored Kafka checkpoints back to the job state dir before `submit_job` resumes sources. 
+pub fn materialize_kafka_source_checkpoints_from_catalog( + job_dir: &Path, + checkpoints: &[pb::KafkaSourceSubtaskCheckpoint], +) -> DFResult<()> { + if checkpoints.is_empty() { + return Ok(()); + } + std::fs::create_dir_all(job_dir).map_err(|e| { + datafusion::common::DataFusionError::Execution(format!( + "create job state dir {}: {e}", + job_dir.display() + )) + })?; + for c in checkpoints { + let saved = kafka_snap::KafkaSourceSavedOffsets { + epoch: c.checkpoint_epoch, + partitions: c + .partitions + .iter() + .map(|p| kafka_snap::KafkaState { + partition: p.partition, + offset: p.offset, + }) + .collect(), + }; + let path = kafka_snap::kafka_snapshot_path(job_dir, c.pipeline_id, c.subtask_index); + let bytes = kafka_snap::encode_kafka_offset_snapshot(&saved).map_err(|e| { + datafusion::common::DataFusionError::Execution(format!( + "encode kafka snapshot for {}: {e}", + path.display() + )) + })?; + std::fs::write(&path, &bytes).map_err(|e| { + datafusion::common::DataFusionError::Execution(format!( + "write kafka snapshot {}: {e}", + path.display() + )) + })?; + } + Ok(()) +} + pub struct CatalogManager { store: Arc, } @@ -98,6 +188,7 @@ impl CatalogManager { comment: comment.to_string(), checkpoint_interval_ms, latest_checkpoint_epoch: 0, + kafka_source_checkpoints: vec![], }; let payload = def.encode_to_vec(); let key = Self::build_streaming_job_key(table_name); @@ -115,7 +206,18 @@ impl CatalogManager { /// Persist the globally-completed checkpoint epoch after all operators ACK. /// Only advances forward; stale epochs are silently ignored. - pub fn commit_job_checkpoint(&self, table_name: &str, epoch: u64) -> DFResult<()> { + /// + /// `kafka_source_checkpoints` is assembled by the job coordinator from source pipeline checkpoint + /// ACKs (in-memory); it is stored next to `latest_checkpoint_epoch` in the catalog. + /// + /// `job_state_dir` is only used to remove legacy on-disk staging snapshots for this epoch, if present. 
+ pub fn commit_job_checkpoint( + &self, + table_name: &str, + epoch: u64, + job_state_dir: &Path, + kafka_source_checkpoints: Vec, + ) -> DFResult<()> { let key = Self::build_streaming_job_key(table_name); let current_payload = self.store.get(&key)?.ok_or_else(|| { @@ -135,16 +237,23 @@ impl CatalogManager { if epoch > def.latest_checkpoint_epoch { def.latest_checkpoint_epoch = epoch; + def.kafka_source_checkpoints = kafka_source_checkpoints; let new_payload = def.encode_to_vec(); self.store.put(&key, new_payload)?; - debug!(table = %table_name, epoch = epoch, "Checkpoint metadata committed to Catalog"); + debug!( + table = %table_name, + epoch = epoch, + kafka_subtasks = def.kafka_source_checkpoints.len(), + "Checkpoint metadata committed to Catalog" + ); + cleanup_kafka_offset_snapshots_for_epoch(job_state_dir, epoch); } Ok(()) } - /// Returns (table_name, program, checkpoint_interval_ms, latest_checkpoint_epoch). - pub fn load_streaming_job_definitions(&self) -> DFResult> { + /// Load all persisted streaming jobs (including Kafka offset checkpoints for restore). 
+ pub fn load_streaming_job_definitions(&self) -> DFResult> { let records = self.store.scan_prefix(STREAMING_JOB_KEY_PREFIX)?; let mut out = Vec::with_capacity(records.len()); for (key, payload) in records { @@ -170,12 +279,13 @@ impl CatalogManager { continue; } }; - out.push(( - def.table_name, + out.push(StoredStreamingJob { + table_name: def.table_name, program, - def.checkpoint_interval_ms, - def.latest_checkpoint_epoch, - )); + checkpoint_interval_ms: def.checkpoint_interval_ms, + latest_checkpoint_epoch: def.latest_checkpoint_epoch, + kafka_source_checkpoints: def.kafka_source_checkpoints, + }); } Ok(out) } @@ -561,10 +671,28 @@ pub fn restore_streaming_jobs_from_store() { let mut restored = 0usize; let mut failed = 0usize; - for (table_name, fs_program, interval_ms, latest_epoch) in definitions { + for job in definitions { + let StoredStreamingJob { + table_name, + program, + checkpoint_interval_ms: interval_ms, + latest_checkpoint_epoch: latest_epoch, + kafka_source_checkpoints, + } = job; let jm = job_manager.clone(); let name = table_name.clone(); + let job_dir = jm.job_state_directory(&table_name); + if let Err(e) = + materialize_kafka_source_checkpoints_from_catalog(&job_dir, &kafka_source_checkpoints) + { + warn!( + table = %table_name, + error = %e, + "Failed to materialize Kafka checkpoints from catalog before job restore" + ); + } + let custom_interval = if interval_ms > 0 { Some(interval_ms) } else { @@ -576,8 +704,7 @@ pub fn restore_streaming_jobs_from_store() { None }; - match rt.block_on(jm.submit_job(name.clone(), fs_program, custom_interval, recovery_epoch)) - { + match rt.block_on(jm.submit_job(name.clone(), program, custom_interval, recovery_epoch)) { Ok(job_id) => { info!( table = %table_name, job_id = %job_id, diff --git a/src/storage/stream_catalog/mod.rs b/src/storage/stream_catalog/mod.rs index 6f31317a..ef176c40 100644 --- a/src/storage/stream_catalog/mod.rs +++ b/src/storage/stream_catalog/mod.rs @@ -17,8 +17,10 @@ mod manager; mod 
meta_store; mod rocksdb_meta_store; +#[allow(unused_imports)] pub use manager::{ - CatalogManager, initialize_stream_catalog, restore_global_catalog_from_store, + CatalogManager, StoredStreamingJob, initialize_stream_catalog, + materialize_kafka_source_checkpoints_from_catalog, restore_global_catalog_from_store, restore_streaming_jobs_from_store, }; pub use meta_store::{InMemoryMetaStore, MetaStore}; From 1b7e910f2be3e19af5c7d4c7529f5e6c01e10735 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Sun, 19 Apr 2026 23:54:51 +0800 Subject: [PATCH 21/26] update --- protocol/proto/storage.proto | 7 ++++ src/runtime/streaming/api/context.rs | 6 +-- src/runtime/streaming/api/source.rs | 14 ++++++- src/runtime/streaming/execution/pipeline.rs | 2 +- .../streaming/execution/source_driver.rs | 2 +- src/runtime/streaming/job/job_manager.rs | 39 +++++++++++++++---- .../streaming/operators/source/kafka/mod.rs | 14 ++++--- src/runtime/streaming/protocol/control.rs | 6 +-- 8 files changed, 67 insertions(+), 23 deletions(-) diff --git a/protocol/proto/storage.proto b/protocol/proto/storage.proto index 66f3b0f2..fd021727 100644 --- a/protocol/proto/storage.proto +++ b/protocol/proto/storage.proto @@ -58,6 +58,13 @@ message KafkaSourceSubtaskCheckpoint { repeated KafkaPartitionOffset partitions = 4; } +// Generic source checkpoint payload envelope (enum-like via oneof). +message SourceCheckpointPayload { + oneof checkpoint { + KafkaSourceSubtaskCheckpoint kafka = 1; + } +} + // Persisted record for one streaming table (CREATE STREAMING TABLE). // On restart, the engine re-submits each record to JobManager to resume the pipeline. 
message StreamingTableDefinition { diff --git a/src/runtime/streaming/api/context.rs b/src/runtime/streaming/api/context.rs index 7ce8cf6c..8b778502 100644 --- a/src/runtime/streaming/api/context.rs +++ b/src/runtime/streaming/api/context.rs @@ -16,7 +16,7 @@ use std::time::{Duration, SystemTime}; use anyhow::{Context, Result, anyhow}; use arrow_array::RecordBatch; -use protocol::storage::KafkaSourceSubtaskCheckpoint; +use protocol::storage::SourceCheckpointPayload; use tokio::sync::mpsc; use crate::runtime::memory::{MemoryBlock, MemoryPool, get_array_memory_size}; @@ -131,14 +131,14 @@ impl TaskContext { pub async fn send_checkpoint_ack( &self, epoch: u64, - kafka_subtask: Option, + source_payloads: Vec, ) { if let Some(tx) = &self.checkpoint_ack_tx { let _ = tx .send(JobMasterEvent::CheckpointAck { pipeline_id: self.pipeline_id, epoch, - kafka_subtask, + source_payloads, }) .await; } diff --git a/src/runtime/streaming/api/source.rs b/src/runtime/streaming/api/source.rs index c6597143..8c1be3db 100644 --- a/src/runtime/streaming/api/source.rs +++ b/src/runtime/streaming/api/source.rs @@ -14,7 +14,7 @@ use crate::runtime::streaming::api::context::TaskContext; use crate::sql::common::{CheckpointBarrier, Watermark}; use arrow_array::RecordBatch; use async_trait::async_trait; -use protocol::storage::KafkaSourceSubtaskCheckpoint; +use protocol::storage::{KafkaSourceSubtaskCheckpoint, SourceCheckpointPayload, source_checkpoint_payload}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum SourceOffset { @@ -35,7 +35,17 @@ pub enum SourceEvent { /// Optional metadata returned when a source completes a checkpoint barrier snapshot. 
#[derive(Debug, Default, Clone)] pub struct SourceCheckpointReport { - pub kafka_subtask: Option, + pub payloads: Vec, +} + +impl SourceCheckpointReport { + pub fn from_kafka_checkpoint(kafka: KafkaSourceSubtaskCheckpoint) -> Self { + Self { + payloads: vec![SourceCheckpointPayload { + checkpoint: Some(source_checkpoint_payload::Checkpoint::Kafka(kafka)), + }], + } + } } #[async_trait] diff --git a/src/runtime/streaming/execution/pipeline.rs b/src/runtime/streaming/execution/pipeline.rs index 114399ae..a9a2d102 100644 --- a/src/runtime/streaming/execution/pipeline.rs +++ b/src/runtime/streaming/execution/pipeline.rs @@ -124,7 +124,7 @@ impl Pipeline { active_streams.insert(i, stream); } } - self.ctx.send_checkpoint_ack(epoch, None).await; + self.ctx.send_checkpoint_ack(epoch, vec![]).await; } } } diff --git a/src/runtime/streaming/execution/source_driver.rs b/src/runtime/streaming/execution/source_driver.rs index 28bb41ca..6fc0c9df 100644 --- a/src/runtime/streaming/execution/source_driver.rs +++ b/src/runtime/streaming/execution/source_driver.rs @@ -187,7 +187,7 @@ impl SourceDriver { if let Some((epoch, report)) = pending_source_checkpoint { self.ctx - .send_checkpoint_ack(epoch, report.kafka_subtask) + .send_checkpoint_ack(epoch, report.payloads) .await; } diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index ea4049aa..2d828a5c 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -23,7 +23,9 @@ use tokio_stream::wrappers::ReceiverStream; use tracing::{debug, error, info, warn}; use protocol::function_stream_graph::{ChainedOperator, FsProgram}; -use protocol::storage::KafkaSourceSubtaskCheckpoint; +use protocol::storage::{ + KafkaSourceSubtaskCheckpoint, SourceCheckpointPayload, source_checkpoint_payload, +}; use crate::config::{ DEFAULT_CHECKPOINT_INTERVAL_MS, DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, @@ -158,6 +160,25 @@ impl PipelineRunner { } } +fn 
decode_kafka_checkpoints_from_source_payloads( + payloads: Vec, + epoch: u64, +) -> Vec { + let mut out = Vec::new(); + for p in payloads { + match p.checkpoint { + Some(source_checkpoint_payload::Checkpoint::Kafka(mut cp)) => { + if cp.checkpoint_epoch != epoch { + cp.checkpoint_epoch = epoch; + } + out.push(cp); + } + None => warn!("Skip empty source checkpoint payload"), + } + } + out +} + impl JobManager { pub fn new( operator_factory: Arc, @@ -721,7 +742,7 @@ impl JobManager { let mut current_epoch: u64 = start_epoch; let mut pending_checkpoints: HashMap> = HashMap::new(); - let mut kafka_reports: HashMap> = HashMap::new(); + let mut source_reports: HashMap> = HashMap::new(); async fn broadcast_checkpoint_phase2( txs: &[mpsc::Sender], @@ -741,10 +762,13 @@ impl JobManager { JobMasterEvent::CheckpointAck { pipeline_id, epoch, - kafka_subtask, + source_payloads, } => { - if let Some(k) = kafka_subtask { - kafka_reports.entry(epoch).or_default().push(k); + if !source_payloads.is_empty() { + source_reports + .entry(epoch) + .or_default() + .extend(source_payloads); } if let Some(pending_set) = pending_checkpoints.get_mut(&epoch) { pending_set.remove(&pipeline_id); @@ -755,7 +779,8 @@ impl JobManager { "Checkpoint Epoch is GLOBALLY COMPLETED (phase 1); persisting metadata and notifying operators (phase 2)" ); - let kf = kafka_reports.remove(&epoch).unwrap_or_default(); + let payloads = source_reports.remove(&epoch).unwrap_or_default(); + let kf = decode_kafka_checkpoints_from_source_payloads(payloads, epoch); let epoch_u32 = u32::try_from(epoch).unwrap_or(u32::MAX); let mut catalog_ok = true; @@ -797,7 +822,7 @@ impl JobManager { reason = %reason, "Checkpoint FAILED!" 
); if pending_checkpoints.remove(&epoch).is_some() { - kafka_reports.remove(&epoch); + source_reports.remove(&epoch); let epoch_u32 = u32::try_from(epoch).unwrap_or(u32::MAX); broadcast_checkpoint_phase2( &all_pipeline_control_txs, diff --git a/src/runtime/streaming/operators/source/kafka/mod.rs b/src/runtime/streaming/operators/source/kafka/mod.rs index 3bf65eca..9f5b84ad 100644 --- a/src/runtime/streaming/operators/source/kafka/mod.rs +++ b/src/runtime/streaming/operators/source/kafka/mod.rs @@ -480,13 +480,15 @@ impl SourceOperator for KafkaSourceOperator { } let epoch = u64::from(barrier.epoch); - let kafka_subtask = if self.current_offsets.is_empty() { - None - } else { + if self.current_offsets.is_empty() { + return Ok(SourceCheckpointReport::default()); + } + + let kafka_subtask = { let mut parts: Vec<(i32, i64)> = self.current_offsets.iter().map(|(&p, &o)| (p, o)).collect(); parts.sort_by_key(|x| x.0); - Some(KafkaSourceSubtaskCheckpoint { + KafkaSourceSubtaskCheckpoint { pipeline_id: ctx.pipeline_id, subtask_index: ctx.subtask_index, checkpoint_epoch: epoch, @@ -494,10 +496,10 @@ impl SourceOperator for KafkaSourceOperator { .into_iter() .map(|(partition, offset)| KafkaPartitionOffset { partition, offset }) .collect(), - }) + } }; - Ok(SourceCheckpointReport { kafka_subtask }) + Ok(SourceCheckpointReport::from_kafka_checkpoint(kafka_subtask)) } async fn on_close(&mut self, _ctx: &mut TaskContext) -> Result<()> { diff --git a/src/runtime/streaming/protocol/control.rs b/src/runtime/streaming/protocol/control.rs index 08719e97..6d0bc492 100644 --- a/src/runtime/streaming/protocol/control.rs +++ b/src/runtime/streaming/protocol/control.rs @@ -11,7 +11,7 @@ // limitations under the License. 
use super::event::CheckpointBarrier; -use protocol::storage::KafkaSourceSubtaskCheckpoint; +use protocol::storage::SourceCheckpointPayload; use serde::{Deserialize, Serialize}; use std::time::Duration; use tokio::sync::mpsc::{self, Receiver, Sender}; @@ -99,8 +99,8 @@ pub enum JobMasterEvent { CheckpointAck { pipeline_id: u32, epoch: u64, - /// Kafka source subtask progress at this barrier (only source pipelines set this). - kafka_subtask: Option, + /// Source protocol checkpoint payloads (enum-style oneof envelope). + source_payloads: Vec, }, CheckpointDecline { pipeline_id: u32, From 272826f998c0e9773fe1c6cf1efed743a58f83c4 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Mon, 20 Apr 2026 00:18:07 +0800 Subject: [PATCH 22/26] update --- src/runtime/streaming/api/source.rs | 4 +++- src/runtime/streaming/execution/source_driver.rs | 4 +--- src/runtime/streaming/job/job_manager.rs | 5 ++++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/runtime/streaming/api/source.rs b/src/runtime/streaming/api/source.rs index 8c1be3db..26851eb2 100644 --- a/src/runtime/streaming/api/source.rs +++ b/src/runtime/streaming/api/source.rs @@ -14,7 +14,9 @@ use crate::runtime::streaming::api::context::TaskContext; use crate::sql::common::{CheckpointBarrier, Watermark}; use arrow_array::RecordBatch; use async_trait::async_trait; -use protocol::storage::{KafkaSourceSubtaskCheckpoint, SourceCheckpointPayload, source_checkpoint_payload}; +use protocol::storage::{ + KafkaSourceSubtaskCheckpoint, SourceCheckpointPayload, source_checkpoint_payload, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum SourceOffset { diff --git a/src/runtime/streaming/execution/source_driver.rs b/src/runtime/streaming/execution/source_driver.rs index 6fc0c9df..9f403053 100644 --- a/src/runtime/streaming/execution/source_driver.rs +++ b/src/runtime/streaming/execution/source_driver.rs @@ -186,9 +186,7 @@ impl SourceDriver { } if let Some((epoch, report)) = pending_source_checkpoint 
{ - self.ctx - .send_checkpoint_ack(epoch, report.payloads) - .await; + self.ctx.send_checkpoint_ack(epoch, report.payloads).await; } Ok(stop) diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index 2d828a5c..130b76b9 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -720,7 +720,10 @@ impl JobManager { // Chandy-Lamport distributed snapshot barrier coordinator // ======================================================================== - fn spawn_checkpoint_coordinator(&self, cfg: CheckpointCoordinatorConfig) -> TokioJoinHandle<()> { + fn spawn_checkpoint_coordinator( + &self, + cfg: CheckpointCoordinatorConfig, + ) -> TokioJoinHandle<()> { let CheckpointCoordinatorConfig { job_id, source_control_txs, From 9eea91a41ccb7b906866d74177212514c6647bd8 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Mon, 20 Apr 2026 01:14:27 +0800 Subject: [PATCH 23/26] update --- src/runtime/streaming/api/operator.rs | 24 ++- .../streaming/execution/operator_chain.rs | 150 ++++++++++++++++-- .../grouping/incremental_aggregate.rs | 17 +- .../operators/joins/join_instance.rs | 22 +-- .../operators/joins/join_with_expiration.rs | 26 +-- src/runtime/streaming/operators/key_by.rs | 20 ++- .../streaming/operators/key_operator.rs | 18 ++- src/runtime/streaming/operators/projection.rs | 20 ++- .../streaming/operators/sink/kafka/mod.rs | 12 +- .../streaming/operators/value_execution.rs | 21 ++- .../watermark/watermark_generator.rs | 52 +++--- .../windows/session_aggregating_window.rs | 21 ++- .../windows/sliding_aggregating_window.rs | 20 +-- .../windows/tumbling_aggregating_window.rs | 24 +-- .../operators/windows/window_function.rs | 18 +-- 15 files changed, 312 insertions(+), 153 deletions(-) diff --git a/src/runtime/streaming/api/operator.rs b/src/runtime/streaming/api/operator.rs index 53fa629f..8eb9e8c4 100644 --- a/src/runtime/streaming/api/operator.rs +++ 
b/src/runtime/streaming/api/operator.rs @@ -16,7 +16,6 @@ use crate::runtime::streaming::protocol::event::StreamOutput; use crate::sql::common::{CheckpointBarrier, Watermark}; use arrow_array::RecordBatch; use async_trait::async_trait; -use std::time::Duration; // --------------------------------------------------------------------------- // ConstructedOperator @@ -27,6 +26,11 @@ pub enum ConstructedOperator { Operator(Box), } +#[async_trait] +pub trait Collector: Send { + async fn collect(&mut self, out: StreamOutput, ctx: &mut TaskContext) -> anyhow::Result<()>; +} + #[async_trait] pub trait Operator: Send + 'static { fn name(&self) -> &str; @@ -40,13 +44,15 @@ pub trait Operator: Send + 'static { input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> anyhow::Result>; + collector: &mut dyn Collector, + ) -> anyhow::Result<()>; async fn process_watermark( &mut self, watermark: Watermark, ctx: &mut TaskContext, - ) -> anyhow::Result>; + collector: &mut dyn Collector, + ) -> anyhow::Result<()>; async fn snapshot_state( &mut self, @@ -75,18 +81,6 @@ pub trait Operator: Send + 'static { Ok(()) } - fn tick_interval(&self) -> Option { - None - } - - async fn process_tick( - &mut self, - _tick_index: u64, - _ctx: &mut TaskContext, - ) -> anyhow::Result> { - Ok(vec![]) - } - async fn on_close(&mut self, _ctx: &mut TaskContext) -> anyhow::Result> { Ok(vec![]) } diff --git a/src/runtime/streaming/execution/operator_chain.rs b/src/runtime/streaming/execution/operator_chain.rs index 0b76568b..6f592eca 100644 --- a/src/runtime/streaming/execution/operator_chain.rs +++ b/src/runtime/streaming/execution/operator_chain.rs @@ -13,7 +13,7 @@ use async_trait::async_trait; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::error::RunError; use crate::runtime::streaming::protocol::{ control::{ControlCommand, 
StopMode}, @@ -120,13 +120,103 @@ impl OperatorDrive for IntermediateDriver { ) -> Result { match tracked.event { StreamEvent::Data(batch) => { - let outputs = self.operator.process_data(input_idx, batch, ctx).await?; - self.dispatch_outputs(outputs, ctx).await?; + struct NextCollector<'a> { + next: &'a mut Box, + op_name: String, + } + #[async_trait] + impl Collector for NextCollector<'_> { + async fn collect( + &mut self, + out: StreamOutput, + ctx: &mut TaskContext, + ) -> anyhow::Result<()> { + match out { + StreamOutput::Forward(b) => { + self.next + .process_event( + 0, + TrackedEvent::control(StreamEvent::Data(b)), + ctx, + ) + .await?; + } + StreamOutput::Watermark(wm) => { + self.next + .process_event( + 0, + TrackedEvent::control(StreamEvent::Watermark(wm)), + ctx, + ) + .await?; + } + StreamOutput::Keyed(_, _) | StreamOutput::Broadcast(_) => { + return Err(anyhow::anyhow!( + "Topology Violation: Keyed or Broadcast output emitted in the middle of chain by '{}'", + self.op_name + )); + } + } + Ok(()) + } + } + let mut collector = NextCollector { + next: &mut self.next, + op_name: self.operator.name().to_string(), + }; + self.operator + .process_data(input_idx, batch, ctx, &mut collector) + .await?; Ok(false) } StreamEvent::Watermark(wm) => { - let outputs = self.operator.process_watermark(wm, ctx).await?; - self.dispatch_outputs(outputs, ctx).await?; + struct NextCollector<'a> { + next: &'a mut Box, + op_name: String, + } + #[async_trait] + impl Collector for NextCollector<'_> { + async fn collect( + &mut self, + out: StreamOutput, + ctx: &mut TaskContext, + ) -> anyhow::Result<()> { + match out { + StreamOutput::Forward(b) => { + self.next + .process_event( + 0, + TrackedEvent::control(StreamEvent::Data(b)), + ctx, + ) + .await?; + } + StreamOutput::Watermark(wm) => { + self.next + .process_event( + 0, + TrackedEvent::control(StreamEvent::Watermark(wm)), + ctx, + ) + .await?; + } + StreamOutput::Keyed(_, _) | StreamOutput::Broadcast(_) => { + return 
Err(anyhow::anyhow!( + "Topology Violation: Keyed or Broadcast output emitted in the middle of chain by '{}'", + self.op_name + )); + } + } + Ok(()) + } + } + let mut collector = NextCollector { + next: &mut self.next, + op_name: self.operator.name().to_string(), + }; + self.operator + .process_watermark(wm, ctx, &mut collector) + .await?; self.forward_signal(StreamEvent::Watermark(wm), ctx).await?; Ok(false) } @@ -237,13 +327,55 @@ impl OperatorDrive for TailDriver { ) -> Result { match tracked.event { StreamEvent::Data(batch) => { - let outputs = self.operator.process_data(input_idx, batch, ctx).await?; - self.dispatch_outputs(outputs, ctx).await?; + struct FinalCollector; + #[async_trait] + impl Collector for FinalCollector { + async fn collect( + &mut self, + out: StreamOutput, + ctx: &mut TaskContext, + ) -> anyhow::Result<()> { + match out { + StreamOutput::Forward(b) => ctx.collect(b).await?, + StreamOutput::Keyed(hash, b) => ctx.collect_keyed(hash, b).await?, + StreamOutput::Broadcast(b) => ctx.collect(b).await?, + StreamOutput::Watermark(wm) => { + ctx.broadcast(StreamEvent::Watermark(wm)).await? + } + } + Ok(()) + } + } + let mut collector = FinalCollector; + self.operator + .process_data(input_idx, batch, ctx, &mut collector) + .await?; Ok(false) } StreamEvent::Watermark(wm) => { - let outputs = self.operator.process_watermark(wm, ctx).await?; - self.dispatch_outputs(outputs, ctx).await?; + struct FinalCollector; + #[async_trait] + impl Collector for FinalCollector { + async fn collect( + &mut self, + out: StreamOutput, + ctx: &mut TaskContext, + ) -> anyhow::Result<()> { + match out { + StreamOutput::Forward(b) => ctx.collect(b).await?, + StreamOutput::Keyed(hash, b) => ctx.collect_keyed(hash, b).await?, + StreamOutput::Broadcast(b) => ctx.collect(b).await?, + StreamOutput::Watermark(wm) => { + ctx.broadcast(StreamEvent::Watermark(wm)).await? 
+ } + } + Ok(()) + } + } + let mut collector = FinalCollector; + self.operator + .process_watermark(wm, ctx, &mut collector) + .await?; self.forward_signal(StreamEvent::Watermark(wm), ctx).await?; Ok(false) } diff --git a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs index a8983e99..a2325e7c 100644 --- a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs +++ b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs @@ -43,7 +43,7 @@ use tracing::{debug, info, warn}; // ========================================================================= use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; use crate::runtime::streaming::operators::{Key, UpdatingCache}; use crate::runtime::streaming::state::OperatorStateStore; @@ -876,26 +876,29 @@ impl Operator for IncrementalAggregatingFunc { _input_idx: usize, batch: RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { if self.has_routing_keys { self.keyed_aggregate(&batch)?; } else { self.global_aggregate(&batch)?; } - Ok(vec![]) + Ok(()) } async fn process_watermark( &mut self, _watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { if let Some(changelog_batch) = self.generate_changelog()? 
{ - Ok(vec![StreamOutput::Forward(changelog_batch)]) - } else { - Ok(vec![]) + collector + .collect(StreamOutput::Forward(changelog_batch), _ctx) + .await?; } + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/joins/join_instance.rs b/src/runtime/streaming/operators/joins/join_instance.rs index cddeeff2..098e5a73 100644 --- a/src/runtime/streaming/operators/joins/join_instance.rs +++ b/src/runtime/streaming/operators/joins/join_instance.rs @@ -27,7 +27,7 @@ use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark}; @@ -257,23 +257,25 @@ impl Operator for InstantJoinOperator { input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let side = if input_idx == 0 { JoinSide::Left } else { JoinSide::Right }; self.process_side_internal(side, batch, ctx).await?; - Ok(vec![]) + Ok(()) } async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let Watermark::EventTime(current_time) = watermark else { - return Ok(vec![]); + return Ok(()); }; let store = self.state_store.clone().unwrap(); let cutoff_nanos = current_time.duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64; @@ -288,7 +290,7 @@ impl Operator for InstantJoinOperator { .collect(); if expired_ts.is_empty() { - return Ok(vec![]); + return Ok(()); } // Phase 1: Harvest — extract all expired timestamp data from LSM-Tree @@ -323,15 +325,15 @@ impl Operator for InstantJoinOperator { } // Phase 2: Compute — all data extracted, no store 
reference held - let mut emit_outputs = Vec::new(); - for (_, left_input, right_input) in pending_pairs { if left_input.num_rows() == 0 && right_input.num_rows() == 0 { continue; } let results = self.compute_pair(left_input, right_input).await?; for batch in results { - emit_outputs.push(StreamOutput::Forward(batch)); + collector + .collect(StreamOutput::Forward(batch), _ctx) + .await?; } } @@ -347,7 +349,7 @@ impl Operator for InstantJoinOperator { self.right_state.active_timestamps.remove(&ts); } - Ok(emit_outputs) + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/joins/join_with_expiration.rs b/src/runtime/streaming/operators/joins/join_with_expiration.rs index 92089428..6a2a240c 100644 --- a/src/runtime/streaming/operators/joins/join_with_expiration.rs +++ b/src/runtime/streaming/operators/joins/join_with_expiration.rs @@ -26,7 +26,7 @@ use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark}; @@ -175,7 +175,8 @@ impl JoinWithExpirationOperator { side: JoinSide, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let current_time = ctx.current_watermark().unwrap_or_else(SystemTime::now); let store = self .state_store @@ -204,7 +205,7 @@ impl JoinWithExpirationOperator { }; if opposite_batches.is_empty() { - return Ok(vec![]); + return Ok(()); } let opposite_schema = match side { @@ -224,11 +225,10 @@ impl JoinWithExpirationOperator { }; let result_batches = self.compute_pair(left_input, right_input).await?; - - Ok(result_batches - .into_iter() - .map(StreamOutput::Forward) - .collect()) + for b in 
result_batches { + collector.collect(StreamOutput::Forward(b), ctx).await?; + } + Ok(()) } } @@ -284,21 +284,23 @@ impl Operator for JoinWithExpirationOperator { input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let side = if input_idx == 0 { JoinSide::Left } else { JoinSide::Right }; - self.process_side(side, batch, ctx).await + self.process_side(side, batch, ctx, collector).await } async fn process_watermark( &mut self, _watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![]) + _collector: &mut dyn Collector, + ) -> Result<()> { + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/key_by.rs b/src/runtime/streaming/operators/key_by.rs index 59206688..90c55d08 100644 --- a/src/runtime/streaming/operators/key_by.rs +++ b/src/runtime/streaming/operators/key_by.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::sql::common::{CheckpointBarrier, Watermark}; use protocol::function_stream_graph::KeyPlanOperator; @@ -57,10 +57,11 @@ impl Operator for KeyByOperator { _input_idx: usize, batch: RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let num_rows = batch.num_rows(); if num_rows == 0 { - return Ok(vec![]); + return Ok(()); } let mut key_columns = Vec::with_capacity(self.key_extractors.len()); @@ -110,15 +111,22 @@ impl Operator for KeyByOperator { start_idx = end_idx; } - Ok(outputs) + for out in outputs { + collector.collect(out, _ctx).await?; + } + Ok(()) } async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![StreamOutput::Watermark(watermark)]) + collector: &mut dyn Collector, + ) -> Result<()> { 
+ collector + .collect(StreamOutput::Watermark(watermark), _ctx) + .await?; + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/key_operator.rs b/src/runtime/streaming/operators/key_operator.rs index 1f4f48c6..7a89d2f2 100644 --- a/src/runtime/streaming/operators/key_operator.rs +++ b/src/runtime/streaming/operators/key_operator.rs @@ -17,7 +17,7 @@ use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::operators::StatelessPhysicalExecutor; use crate::sql::common::{CheckpointBarrier, Watermark}; use ahash::RandomState; @@ -67,7 +67,8 @@ impl Operator for KeyExecutionOperator { _input_idx: usize, batch: RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let mut outputs = Vec::new(); let mut stream = self.executor.process_batch(batch).await?; @@ -122,15 +123,22 @@ impl Operator for KeyExecutionOperator { start_idx = end_idx; } } - Ok(outputs) + for out in outputs { + collector.collect(out, _ctx).await?; + } + Ok(()) } async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![StreamOutput::Watermark(watermark)]) + collector: &mut dyn Collector, + ) -> Result<()> { + collector + .collect(StreamOutput::Watermark(watermark), _ctx) + .await?; + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/projection.rs b/src/runtime/streaming/operators/projection.rs index 1a2ff3a1..b84d74aa 100644 --- a/src/runtime/streaming/operators/projection.rs +++ b/src/runtime/streaming/operators/projection.rs @@ -24,7 +24,7 @@ use protocol::function_stream_graph::ProjectionOperator as ProjectionOperatorPro use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use 
crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::global::Registry; use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark}; use crate::sql::logical_node::logical::OperatorName; @@ -98,9 +98,10 @@ impl Operator for ProjectionOperator { _input_idx: usize, batch: RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { if batch.num_rows() == 0 { - return Ok(vec![]); + return Ok(()); } let projected_columns = self @@ -114,15 +115,22 @@ impl Operator for ProjectionOperator { let out_batch = RecordBatch::try_new(self.output_schema.schema.clone(), projected_columns)?; - Ok(vec![StreamOutput::Forward(out_batch)]) + collector + .collect(StreamOutput::Forward(out_batch), _ctx) + .await?; + Ok(()) } async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![StreamOutput::Watermark(watermark)]) + collector: &mut dyn Collector, + ) -> Result<()> { + collector + .collect(StreamOutput::Watermark(watermark), _ctx) + .await?; + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/sink/kafka/mod.rs b/src/runtime/streaming/operators/sink/kafka/mod.rs index bee4d367..a9c4b50e 100644 --- a/src/runtime/streaming/operators/sink/kafka/mod.rs +++ b/src/runtime/streaming/operators/sink/kafka/mod.rs @@ -38,7 +38,7 @@ use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::format::DataSerializer; use crate::sql::common::constants::factory_operator_name; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark}; @@ -260,7 +260,8 @@ impl Operator for KafkaSinkOperator { _input_idx: usize, batch: 
RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let payloads = self.serializer.serialize(&batch)?; let producer = self.current_producer().clone(); @@ -298,15 +299,16 @@ impl Operator for KafkaSinkOperator { } } - Ok(vec![]) + Ok(()) } async fn process_watermark( &mut self, _watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![]) + _collector: &mut dyn Collector, + ) -> Result<()> { + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/value_execution.rs b/src/runtime/streaming/operators/value_execution.rs index ff952dda..b93cd78b 100644 --- a/src/runtime/streaming/operators/value_execution.rs +++ b/src/runtime/streaming/operators/value_execution.rs @@ -17,7 +17,7 @@ use futures::StreamExt; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::operators::StatelessPhysicalExecutor; use crate::sql::common::{CheckpointBarrier, Watermark}; @@ -43,26 +43,31 @@ impl Operator for ValueExecutionOperator { _input_idx: usize, batch: RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { - let mut outputs = Vec::new(); - + collector: &mut dyn Collector, + ) -> Result<()> { let mut stream = self.executor.process_batch(batch).await?; while let Some(batch_result) = stream.next().await { let out_batch = batch_result?; if out_batch.num_rows() > 0 { - outputs.push(StreamOutput::Forward(out_batch)); + collector + .collect(StreamOutput::Forward(out_batch), _ctx) + .await?; } } - Ok(outputs) + Ok(()) } async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![StreamOutput::Watermark(watermark)]) + collector: &mut dyn Collector, + ) -> Result<()> { + collector + .collect(StreamOutput::Watermark(watermark), _ctx) + .await?; + 
Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/watermark/watermark_generator.rs b/src/runtime/streaming/operators/watermark/watermark_generator.rs index b74a92f2..497553eb 100644 --- a/src/runtime/streaming/operators/watermark/watermark_generator.rs +++ b/src/runtime/streaming/operators/watermark/watermark_generator.rs @@ -23,11 +23,11 @@ use datafusion_proto::protobuf::PhysicalExprNode; use prost::Message; use std::sync::Arc; use std::time::{Duration, SystemTime}; -use tracing::{debug, info}; +use tracing::debug; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark, from_nanos, to_millis}; use async_trait::async_trait; @@ -107,10 +107,6 @@ impl Operator for WatermarkGeneratorOperator { "ExpressionWatermarkGenerator" } - fn tick_interval(&self) -> Option { - Some(Duration::from_secs(1)) - } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { self.last_event_wall = SystemTime::now(); Ok(()) @@ -121,13 +117,16 @@ impl Operator for WatermarkGeneratorOperator { _input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { self.last_event_wall = SystemTime::now(); - let mut outputs = vec![StreamOutput::Forward(batch.clone())]; + collector + .collect(StreamOutput::Forward(batch.clone()), ctx) + .await?; let Some(max_batch_ts) = self.extract_max_timestamp(&batch) else { - return Ok(outputs); + return Ok(()); }; let new_watermark = self.evaluate_watermark(&batch)?; @@ -145,42 +144,27 @@ impl Operator for WatermarkGeneratorOperator { to_millis(self.state.max_watermark) ); - outputs.push(StreamOutput::Watermark(Watermark::EventTime( - self.state.max_watermark, - ))); + 
collector + .collect( + StreamOutput::Watermark(Watermark::EventTime(self.state.max_watermark)), + ctx, + ) + .await?; self.state.last_watermark_emitted_at = max_batch_ts; self.is_idle = false; } - Ok(outputs) + Ok(()) } async fn process_watermark( &mut self, _watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![]) - } - - async fn process_tick( - &mut self, - _tick_index: u64, - ctx: &mut TaskContext, - ) -> Result> { - if let Some(idle_timeout) = self.idle_time { - let elapsed = self.last_event_wall.elapsed().unwrap_or(Duration::ZERO); - if !self.is_idle && elapsed > idle_timeout { - info!( - "task [{}] entering Idle after {:?}", - ctx.subtask_index, idle_timeout - ); - self.is_idle = true; - return Ok(vec![StreamOutput::Watermark(Watermark::Idle)]); - } - } - Ok(vec![]) + _collector: &mut dyn Collector, + ) -> Result<()> { + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/windows/session_aggregating_window.rs b/src/runtime/streaming/operators/windows/session_aggregating_window.rs index ad32f73f..2da2c285 100644 --- a/src/runtime/streaming/operators/windows/session_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/session_aggregating_window.rs @@ -38,7 +38,7 @@ use tracing::info; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::converter::Converter; @@ -797,12 +797,13 @@ impl Operator for SessionWindowOperator { _input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let watermark_time = ctx.current_watermark(); let filtered_batch = self.filter_batch_by_time(batch, watermark_time)?; if filtered_batch.num_rows() 
== 0 { - return Ok(vec![]); + return Ok(()); } let sorted_batch = self.sort_batch(&filtered_batch)?; @@ -810,7 +811,7 @@ impl Operator for SessionWindowOperator { self.ingest_sorted_batch(sorted_batch, watermark_time, false) .await?; - Ok(vec![]) + Ok(()) } // Watermark-driven session closure with precise LSM-Tree garbage collection @@ -818,14 +819,15 @@ impl Operator for SessionWindowOperator { &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let Watermark::EventTime(current_time) = watermark else { - return Ok(vec![]); + return Ok(()); }; let completed_sessions = self.evaluate_watermark_with_meta(current_time).await?; if completed_sessions.is_empty() { - return Ok(vec![]); + return Ok(()); } let store = self @@ -855,7 +857,10 @@ impl Operator for SessionWindowOperator { } let output_batch = self.format_to_arrow(completed_sessions)?; - Ok(vec![StreamOutput::Forward(output_batch)]) + collector + .collect(StreamOutput::Forward(output_batch), _ctx) + .await?; + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs index 64d09b8d..3516e950 100644 --- a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs @@ -35,7 +35,7 @@ use tracing::info; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark, from_nanos, to_nanos}; @@ -424,7 +424,8 @@ impl Operator for SlidingWindowOperator { _input_idx: usize, batch: RecordBatch, ctx: &mut 
TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let bin_array = self .binning_function .evaluate(&batch)? @@ -486,7 +487,7 @@ impl Operator for SlidingWindowOperator { .map_err(|e| anyhow!("partial channel send: {e}"))?; } - Ok(vec![]) + Ok(()) } // State morphing (Type 0 → Type 1) and dual-layer GC @@ -494,9 +495,10 @@ impl Operator for SlidingWindowOperator { &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let Watermark::EventTime(current_time) = watermark else { - return Ok(vec![]); + return Ok(()); }; let watermark_bin = self.bin_start(current_time); let store = self @@ -504,8 +506,6 @@ impl Operator for SlidingWindowOperator { .clone() .expect("State store not initialized"); - let mut final_outputs = Vec::new(); - let mut expired_bins = Vec::new(); for &k in self.active_bins.keys() { if k + self.slide <= watermark_bin { @@ -578,7 +578,9 @@ impl Operator for SlidingWindowOperator { .execute(0, SessionContext::new().task_ctx())?; while let Some(batch) = proj_exec.next().await { - final_outputs.push(StreamOutput::Forward(batch?)); + collector + .collect(StreamOutput::Forward(batch?), _ctx) + .await?; } // Phase 5: GC expired partial bins (Type 1) that fall outside the window @@ -600,7 +602,7 @@ impl Operator for SlidingWindowOperator { } } - Ok(final_outputs) + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs index 4e48c50c..6b6b6029 100644 --- a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs @@ -36,7 +36,7 @@ use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use 
crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::time_utils::print_time; @@ -232,7 +232,8 @@ impl Operator for TumblingWindowOperator { _input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let bin_array = self .binning_function .evaluate(&batch)? @@ -298,7 +299,7 @@ impl Operator for TumblingWindowOperator { .map_err(|e| anyhow!("partial channel send: {e}"))?; } - Ok(vec![]) + Ok(()) } // Watermark-driven window closure with LSM-Tree GC @@ -306,17 +307,16 @@ impl Operator for TumblingWindowOperator { &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let Watermark::EventTime(current_time) = watermark else { - return Ok(vec![]); + return Ok(()); }; let store = self .state_store .as_ref() .expect("State store not initialized"); - let mut final_outputs = Vec::new(); - let mut expired_bins = Vec::new(); for &k in self.active_bins.keys() { if k + self.width <= current_time { @@ -353,7 +353,9 @@ impl Operator for TumblingWindowOperator { )?; if self.final_projection.is_none() { - final_outputs.push(StreamOutput::Forward(with_timestamp)); + collector + .collect(StreamOutput::Forward(with_timestamp), _ctx) + .await?; } else { aggregate_results.push(with_timestamp); } @@ -366,7 +368,9 @@ impl Operator for TumblingWindowOperator { final_projection.execute(0, SessionContext::new().task_ctx())?; while let Some(batch) = proj_exec.next().await { - final_outputs.push(StreamOutput::Forward(batch?)); + collector + .collect(StreamOutput::Forward(batch?), _ctx) + .await?; } } @@ -378,7 +382,7 @@ impl Operator for TumblingWindowOperator { self.pending_bins.remove(&bin_start_nanos); } - Ok(final_outputs) + Ok(()) } async fn snapshot_state( diff --git 
a/src/runtime/streaming/operators/windows/window_function.rs b/src/runtime/streaming/operators/windows/window_function.rs index 585e51bb..1249233e 100644 --- a/src/runtime/streaming/operators/windows/window_function.rs +++ b/src/runtime/streaming/operators/windows/window_function.rs @@ -28,7 +28,7 @@ use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::time_utils::print_time; @@ -182,7 +182,8 @@ impl Operator for WindowFunctionOperator { _input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let current_watermark = ctx.current_watermark(); let split_batches = self.filter_and_split_batches(batch, current_watermark)?; let store = self @@ -202,7 +203,7 @@ impl Operator for WindowFunctionOperator { self.pending_timestamps.insert(ts_nanos); } - Ok(vec![]) + Ok(()) } // On-demand compute & GC: pull data from LSM-Tree, run DataFusion, tombstone @@ -210,9 +211,10 @@ impl Operator for WindowFunctionOperator { &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let Watermark::EventTime(current_time) = watermark else { - return Ok(vec![]); + return Ok(()); }; let store = self .state_store @@ -227,8 +229,6 @@ impl Operator for WindowFunctionOperator { .copied() .collect(); - let mut final_outputs = Vec::new(); - for ts in expired_ts { let key = Self::build_state_key(ts); @@ -250,7 +250,7 @@ impl Operator for WindowFunctionOperator { drop(tx); while let Some(res) = stream.next().await { - final_outputs.push(StreamOutput::Forward(res?)); + collector.collect(StreamOutput::Forward(res?), _ctx).await?; 
} } @@ -258,7 +258,7 @@ impl Operator for WindowFunctionOperator { self.pending_timestamps.remove(&ts); } - Ok(final_outputs) + Ok(()) } async fn snapshot_state( From 1984b30c10272b5e954df1fc9323f0e6fe3debe5 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Mon, 20 Apr 2026 01:33:58 +0800 Subject: [PATCH 24/26] update --- .../streaming/execution/operator_chain.rs | 274 +++++++----------- 1 file changed, 97 insertions(+), 177 deletions(-) diff --git a/src/runtime/streaming/execution/operator_chain.rs b/src/runtime/streaming/execution/operator_chain.rs index 6f592eca..88e8f441 100644 --- a/src/runtime/streaming/execution/operator_chain.rs +++ b/src/runtime/streaming/execution/operator_chain.rs @@ -10,6 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::anyhow; use async_trait::async_trait; use crate::runtime::streaming::api::context::TaskContext; @@ -21,29 +22,39 @@ use crate::runtime::streaming::protocol::{ }; use crate::sql::common::CheckpointBarrier; +// ============================================================================ +// Core Traits +// ============================================================================ + #[async_trait] pub trait OperatorDrive: Send { async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<(), RunError>; + async fn process_event( &mut self, input_idx: usize, event: TrackedEvent, ctx: &mut TaskContext, ) -> Result; + async fn handle_control( &mut self, cmd: ControlCommand, ctx: &mut TaskContext, ) -> Result; + async fn on_close(&mut self, ctx: &mut TaskContext) -> Result<(), RunError>; } +// ============================================================================ +// Chain Builder +// ============================================================================ + pub struct ChainBuilder; impl ChainBuilder { pub fn build(mut operators: Vec>) -> Option> { let tail_operator = operators.pop()?; - let mut current_driver: Box = 
Box::new(TailDriver::new(tail_operator)); while let Some(op) = operators.pop() { @@ -54,6 +65,68 @@ impl ChainBuilder { } } +// ============================================================================ +// Collectors (Zero-Allocation Emission Abstractions) +// ============================================================================ + +struct ChainedCollector<'a> { + next: &'a mut dyn OperatorDrive, + op_name: String, +} + +impl<'a> ChainedCollector<'a> { + fn new(next: &'a mut dyn OperatorDrive, op_name: &str) -> Self { + Self { + next, + op_name: op_name.to_string(), + } + } +} + +#[async_trait] +impl<'a> Collector for ChainedCollector<'a> { + async fn collect(&mut self, out: StreamOutput, ctx: &mut TaskContext) -> anyhow::Result<()> { + match out { + StreamOutput::Forward(b) => { + self.next + .process_event(0, TrackedEvent::control(StreamEvent::Data(b)), ctx) + .await?; + } + StreamOutput::Watermark(wm) => { + self.next + .process_event(0, TrackedEvent::control(StreamEvent::Watermark(wm)), ctx) + .await?; + } + StreamOutput::Keyed(_, _) | StreamOutput::Broadcast(_) => { + return Err(anyhow!( + "Topology Violation: Keyed or Broadcast output emitted in the middle of chain by '{}'", + self.op_name + )); + } + } + Ok(()) + } +} + +struct TaskCollector; + +#[async_trait] +impl Collector for TaskCollector { + async fn collect(&mut self, out: StreamOutput, ctx: &mut TaskContext) -> anyhow::Result<()> { + match out { + StreamOutput::Forward(b) => ctx.collect(b).await?, + StreamOutput::Keyed(hash, b) => ctx.collect_keyed(hash, b).await?, + StreamOutput::Broadcast(b) => ctx.collect(b).await?, + StreamOutput::Watermark(wm) => ctx.broadcast(StreamEvent::Watermark(wm)).await?, + } + Ok(()) + } +} + +// ============================================================================ +// Intermediate Driver (Middle of the Chain) +// ============================================================================ + pub struct IntermediateDriver { operator: Box, next: Box, @@ 
-64,34 +137,6 @@ impl IntermediateDriver { Self { operator, next } } - async fn dispatch_outputs( - &mut self, - outputs: Vec, - ctx: &mut TaskContext, - ) -> Result<(), RunError> { - for out in outputs { - match out { - StreamOutput::Forward(b) => { - self.next - .process_event(0, TrackedEvent::control(StreamEvent::Data(b)), ctx) - .await?; - } - StreamOutput::Watermark(wm) => { - self.next - .process_event(0, TrackedEvent::control(StreamEvent::Watermark(wm)), ctx) - .await?; - } - StreamOutput::Keyed(_, _) | StreamOutput::Broadcast(_) => { - return Err(RunError::internal(format!( - "Topology Violation: Keyed or Broadcast output emitted in the middle of chain by '{}'", - self.operator.name() - ))); - } - } - } - Ok(()) - } - async fn forward_signal( &mut self, event: StreamEvent, @@ -120,100 +165,14 @@ impl OperatorDrive for IntermediateDriver { ) -> Result { match tracked.event { StreamEvent::Data(batch) => { - struct NextCollector<'a> { - next: &'a mut Box, - op_name: String, - } - #[async_trait] - impl Collector for NextCollector<'_> { - async fn collect( - &mut self, - out: StreamOutput, - ctx: &mut TaskContext, - ) -> anyhow::Result<()> { - match out { - StreamOutput::Forward(b) => { - self.next - .process_event( - 0, - TrackedEvent::control(StreamEvent::Data(b)), - ctx, - ) - .await?; - } - StreamOutput::Watermark(wm) => { - self.next - .process_event( - 0, - TrackedEvent::control(StreamEvent::Watermark(wm)), - ctx, - ) - .await?; - } - StreamOutput::Keyed(_, _) | StreamOutput::Broadcast(_) => { - return Err(anyhow::anyhow!( - "Topology Violation: Keyed or Broadcast output emitted in the middle of chain by '{}'", - self.op_name - )); - } - } - Ok(()) - } - } - let mut collector = NextCollector { - next: &mut self.next, - op_name: self.operator.name().to_string(), - }; + let mut collector = ChainedCollector::new(self.next.as_mut(), self.operator.name()); self.operator .process_data(input_idx, batch, ctx, &mut collector) .await?; Ok(false) } 
StreamEvent::Watermark(wm) => { - struct NextCollector<'a> { - next: &'a mut Box, - op_name: String, - } - #[async_trait] - impl Collector for NextCollector<'_> { - async fn collect( - &mut self, - out: StreamOutput, - ctx: &mut TaskContext, - ) -> anyhow::Result<()> { - match out { - StreamOutput::Forward(b) => { - self.next - .process_event( - 0, - TrackedEvent::control(StreamEvent::Data(b)), - ctx, - ) - .await?; - } - StreamOutput::Watermark(wm) => { - self.next - .process_event( - 0, - TrackedEvent::control(StreamEvent::Watermark(wm)), - ctx, - ) - .await?; - } - StreamOutput::Keyed(_, _) | StreamOutput::Broadcast(_) => { - return Err(anyhow::anyhow!( - "Topology Violation: Keyed or Broadcast output emitted in the middle of chain by '{}'", - self.op_name - )); - } - } - Ok(()) - } - } - let mut collector = NextCollector { - next: &mut self.next, - op_name: self.operator.name().to_string(), - }; + let mut collector = ChainedCollector::new(self.next.as_mut(), self.operator.name()); self.operator .process_watermark(wm, ctx, &mut collector) .await?; @@ -242,8 +201,9 @@ impl OperatorDrive for IntermediateDriver { match &cmd { ControlCommand::TriggerCheckpoint { barrier } => { - let b: CheckpointBarrier = barrier.clone().into(); - self.operator.snapshot_state(b, ctx).await?; + self.operator + .snapshot_state(barrier.clone().into(), ctx) + .await?; } ControlCommand::Commit { epoch } => { self.operator.commit_checkpoint(*epoch, ctx).await?; @@ -266,12 +226,22 @@ impl OperatorDrive for IntermediateDriver { async fn on_close(&mut self, ctx: &mut TaskContext) -> Result<(), RunError> { let close_outs = self.operator.on_close(ctx).await?; - self.dispatch_outputs(close_outs, ctx).await?; + let mut collector = ChainedCollector::new(self.next.as_mut(), self.operator.name()); + + // Reuse the Collector to handle the outputs produced by on_close + for out in close_outs { + collector.collect(out, ctx).await?; + } + self.next.on_close(ctx).await?; Ok(()) } } +//
============================================================================ +// Tail Driver (End of the Chain) +// ============================================================================ + pub struct TailDriver { operator: Box, } @@ -281,22 +251,6 @@ impl TailDriver { Self { operator } } - async fn dispatch_outputs( - &mut self, - outputs: Vec, - ctx: &mut TaskContext, - ) -> Result<(), RunError> { - for out in outputs { - match out { - StreamOutput::Forward(b) => ctx.collect(b).await?, - StreamOutput::Keyed(hash, b) => ctx.collect_keyed(hash, b).await?, - StreamOutput::Broadcast(b) => ctx.collect(b).await?, - StreamOutput::Watermark(wm) => ctx.broadcast(StreamEvent::Watermark(wm)).await?, - } - } - Ok(()) - } - async fn forward_signal( &mut self, event: StreamEvent, @@ -327,52 +281,14 @@ impl OperatorDrive for TailDriver { ) -> Result { match tracked.event { StreamEvent::Data(batch) => { - struct FinalCollector; - #[async_trait] - impl Collector for FinalCollector { - async fn collect( - &mut self, - out: StreamOutput, - ctx: &mut TaskContext, - ) -> anyhow::Result<()> { - match out { - StreamOutput::Forward(b) => ctx.collect(b).await?, - StreamOutput::Keyed(hash, b) => ctx.collect_keyed(hash, b).await?, - StreamOutput::Broadcast(b) => ctx.collect(b).await?, - StreamOutput::Watermark(wm) => { - ctx.broadcast(StreamEvent::Watermark(wm)).await? 
- } - } - Ok(()) - } - } - let mut collector = FinalCollector; + let mut collector = TaskCollector; self.operator .process_data(input_idx, batch, ctx, &mut collector) .await?; Ok(false) } StreamEvent::Watermark(wm) => { - struct FinalCollector; - #[async_trait] - impl Collector for FinalCollector { - async fn collect( - &mut self, - out: StreamOutput, - ctx: &mut TaskContext, - ) -> anyhow::Result<()> { - match out { - StreamOutput::Forward(b) => ctx.collect(b).await?, - StreamOutput::Keyed(hash, b) => ctx.collect_keyed(hash, b).await?, - StreamOutput::Broadcast(b) => ctx.collect(b).await?, - StreamOutput::Watermark(wm) => { - ctx.broadcast(StreamEvent::Watermark(wm)).await? - } - } - Ok(()) - } - } - let mut collector = FinalCollector; + let mut collector = TaskCollector; self.operator .process_watermark(wm, ctx, &mut collector) .await?; @@ -422,7 +338,11 @@ impl OperatorDrive for TailDriver { async fn on_close(&mut self, ctx: &mut TaskContext) -> Result<(), RunError> { let close_outs = self.operator.on_close(ctx).await?; - self.dispatch_outputs(close_outs, ctx).await?; + let mut collector = TaskCollector; + + for out in close_outs { + collector.collect(out, ctx).await?; + } Ok(()) } } From 9c6bce6b19563470ab0a4c0d81ec371434437805 Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Mon, 20 Apr 2026 02:46:35 +0800 Subject: [PATCH 25/26] update --- src/sql/schema/kafka_operator_config.rs | 13 +- tests/integration/framework/kafka_manager.py | 15 +- .../test/wasm/python_sdk/conftest.py | 25 +- .../test/wasm/python_sdk/test_data_flow.py | 13 +- .../python_sdk/test_streaming_sql_kafka.py | 256 ++++++++++++++++++ 5 files changed, 311 insertions(+), 11 deletions(-) create mode 100644 tests/integration/test/wasm/python_sdk/test_streaming_sql_kafka.py diff --git a/src/sql/schema/kafka_operator_config.rs b/src/sql/schema/kafka_operator_config.rs index d9251310..d87dda8f 100644 --- a/src/sql/schema/kafka_operator_config.rs +++ b/src/sql/schema/kafka_operator_config.rs @@ -24,6 
+24,9 @@ use crate::sql::common::formats::{ use crate::sql::common::with_option_keys as opt; use crate::sql::schema::table_role::TableRole; +const STREAMING_JOB_OPTION_CHECKPOINT_INTERVAL: &str = "checkpoint.interval"; +const STREAMING_JOB_OPTION_PARALLELISM: &str = "parallelism"; + fn sql_format_to_proto(fmt: &SqlFormat) -> DFResult { match fmt { SqlFormat::Json(j) => Ok(FormatConfig { @@ -194,7 +197,10 @@ pub fn build_kafka_proto_config( }; let group_id_prefix = options.pull_opt_str(opt::KAFKA_GROUP_ID_PREFIX)?; - let client_configs = options.drain_remaining_string_values()?; + let mut client_configs = options.drain_remaining_string_values()?; + // Streaming job-level options are parsed by planner/coordinator, not Kafka client. + client_configs.remove(STREAMING_JOB_OPTION_CHECKPOINT_INTERVAL); + client_configs.remove(STREAMING_JOB_OPTION_PARALLELISM); Ok(ProtoConfig::KafkaSource(KafkaSourceConfig { topic, @@ -242,7 +248,10 @@ pub fn build_kafka_proto_config( None => options.pull_opt_str(opt::KAFKA_TIMESTAMP_FIELD_LEGACY)?, }; - let client_configs = options.drain_remaining_string_values()?; + let mut client_configs = options.drain_remaining_string_values()?; + // Streaming job-level options are parsed by planner/coordinator, not Kafka client. 
+ client_configs.remove(STREAMING_JOB_OPTION_CHECKPOINT_INTERVAL); + client_configs.remove(STREAMING_JOB_OPTION_PARALLELISM); Ok(ProtoConfig::KafkaSink(KafkaSinkConfig { topic, diff --git a/tests/integration/framework/kafka_manager.py b/tests/integration/framework/kafka_manager.py index e495f638..3898fc7d 100644 --- a/tests/integration/framework/kafka_manager.py +++ b/tests/integration/framework/kafka_manager.py @@ -19,7 +19,7 @@ import logging import time -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, TypeVar import docker @@ -28,6 +28,8 @@ from confluent_kafka import TopicPartition as _new_topic_partition from confluent_kafka.admin import AdminClient, NewTopic +from .utils import find_free_port + logger = logging.getLogger(__name__) T = TypeVar("T") @@ -97,7 +99,10 @@ def __init__( config: Optional[KafkaConfig] = None, docker_client: Optional[docker.DockerClient] = None, ) -> None: - self.config = config or KafkaConfig() + if config is None: + host_port = find_free_port() + config = KafkaConfig(bootstrap_servers=f"127.0.0.1:{host_port}") + self.config = config # Dependency Injection: Allow passing an existing client, or create a lazy one. 
self._docker_client = docker_client @@ -173,7 +178,11 @@ def _ensure_container(self) -> None: self.docker_client.containers.run( image=self.config.image, name=self.config.container_name, - ports={f"{self.config.internal_port}/tcp": self.config.internal_port}, + ports={ + f"{self.config.internal_port}/tcp": int( + self.config.bootstrap_servers.rsplit(":", 1)[1] + ) + }, environment=self.config.environment_vars, detach=True, remove=True, # Auto-remove on stop diff --git a/tests/integration/test/wasm/python_sdk/conftest.py b/tests/integration/test/wasm/python_sdk/conftest.py index aa5d60c6..e0acd26e 100644 --- a/tests/integration/test/wasm/python_sdk/conftest.py +++ b/tests/integration/test/wasm/python_sdk/conftest.py @@ -54,10 +54,25 @@ def kafka_topics(kafka: KafkaDockerManager) -> str: return kafka.config.bootstrap_servers -def _sanitize_node_id(nodeid: str) -> str: - """Converts a pytest nodeid into a safe directory name.""" - clean_name = re.sub(r"[^\w\-]+", "-", nodeid) - return clean_name.strip("-") +def _sanitize_segment(segment: str) -> str: + clean = re.sub(r"[^\w\-]+", "_", segment).strip("_") + return clean or "unknown" + + +def _nodeid_to_workspace_path(nodeid: str) -> str: + """ + Convert pytest nodeid into a readable nested path under target/. + + Example: + test/wasm/python_sdk/test_data_flow.py::TestDataFlow::test_single_word_counting + -> + test/wasm/python_sdk/test_data_flow/TestDataFlow/test_single_word_counting + """ + parts = nodeid.split("::") + file_part = Path(parts[0]).with_suffix("") + file_segments = [_sanitize_segment(seg) for seg in file_part.parts] + extra_segments = [_sanitize_segment(seg) for seg in parts[1:]] + return str(Path(*file_segments, *extra_segments)) @pytest.fixture @@ -66,7 +81,7 @@ def fs_server(request: pytest.FixtureRequest) -> Generator[FunctionStreamInstanc Function-scoped FunctionStream instance. Uses Context Manager to ensure SIGKILL and workspace cleanup. 
""" - test_name = _sanitize_node_id(request.node.nodeid) + test_name = _nodeid_to_workspace_path(request.node.nodeid) with FunctionStreamInstance(test_name=test_name) as instance: yield instance diff --git a/tests/integration/test/wasm/python_sdk/test_data_flow.py b/tests/integration/test/wasm/python_sdk/test_data_flow.py index 9e9532a2..7fc89d7f 100644 --- a/tests/integration/test/wasm/python_sdk/test_data_flow.py +++ b/tests/integration/test/wasm/python_sdk/test_data_flow.py @@ -74,6 +74,13 @@ def consume_messages( deadline = time.time() + timeout try: + logger.info( + "Start consuming topic=%s expected_count=%d timeout=%.1fs bootstrap=%s", + topic, + expected_count, + timeout, + bootstrap, + ) while len(collected) < expected_count and time.time() < deadline: msg = consumer.poll(timeout=POLL_INTERVAL_S) if msg is None: @@ -85,11 +92,15 @@ def consume_messages( payload = msg.value().decode("utf-8") collected.append(json.loads(payload)) + logger.info("Consumed topic=%s count=%d payload=%s", topic, len(collected), payload) finally: consumer.close() if len(collected) < expected_count: - raise TimeoutError(f"Expected {expected_count} messages, received {len(collected)}") + raise TimeoutError( + f"Expected {expected_count} messages, received {len(collected)}. " + f"topic={topic}, collected={collected}" + ) return collected diff --git a/tests/integration/test/wasm/python_sdk/test_streaming_sql_kafka.py b/tests/integration/test/wasm/python_sdk/test_streaming_sql_kafka.py new file mode 100644 index 00000000..dcf5df18 --- /dev/null +++ b/tests/integration/test/wasm/python_sdk/test_streaming_sql_kafka.py @@ -0,0 +1,256 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime as dt +import json +import time +import uuid +from typing import Any, Dict, List + +from .test_data_flow import consume_messages, produce_messages + + +def _uid(prefix: str) -> str: + return f"{prefix}_{uuid.uuid4().hex[:8]}" + + +def _sql_ok(fs_server: Any, sql: str) -> Any: + resp = fs_server.execute_sql(sql) + assert resp.status_code == 200, f"SQL failed: {sql}\nstatus={resp.status_code}\nmsg={resp.message}" + return resp + + +class TestStreamingSqlKafka: + @staticmethod + def _create_impression_source(fs_server: Any, source_name: str, in_topic: str, bootstrap: str) -> None: + _sql_ok( + fs_server, + f""" + CREATE TABLE {source_name} ( + impression_id VARCHAR, + ad_id BIGINT, + campaign_id BIGINT, + user_id VARCHAR, + impression_time TIMESTAMP NOT NULL, + WATERMARK FOR impression_time AS impression_time - INTERVAL '1' SECOND + ) WITH ( + 'connector' = 'kafka', + 'topic' = '{in_topic}', + 'format' = 'json', + 'scan.startup.mode' = 'earliest', + 'bootstrap.servers' = '{bootstrap}' + ); + """, + ) + + @staticmethod + def _create_click_source(fs_server: Any, source_name: str, in_topic: str, bootstrap: str) -> None: + _sql_ok( + fs_server, + f""" + CREATE TABLE {source_name} ( + click_id VARCHAR, + impression_id VARCHAR, + ad_id BIGINT, + click_time TIMESTAMP NOT NULL, + WATERMARK FOR click_time AS click_time - INTERVAL '1' SECOND + ) WITH ( + 'connector' = 'kafka', + 'topic' = '{in_topic}', + 'format' = 'json', + 'scan.startup.mode' = 'earliest', + 'bootstrap.servers' = '{bootstrap}' + ); + """, + ) + + def 
test_tumble_window_with_kafka_produce_consume( + self, + fs_server: Any, + kafka: Any, + kafka_topics: str, + ) -> None: + source_name = _uid("ad_impressions_src") + stream_name = _uid("metric_tumble_impressions") + in_topic = _uid("topic_in") + out_topic = _uid("topic_out") + + kafka.create_topics_if_not_exist([in_topic, out_topic]) + + _sql_ok( + fs_server, + f""" + CREATE TABLE {source_name} ( + impression_id VARCHAR, + campaign_id BIGINT, + impression_time TIMESTAMP NOT NULL, + WATERMARK FOR impression_time AS impression_time - INTERVAL '1' SECOND + ) WITH ( + 'connector' = 'kafka', + 'topic' = '{in_topic}', + 'format' = 'json', + 'scan.startup.mode' = 'earliest', + 'bootstrap.servers' = '{kafka_topics}' + ); + """, + ) + + _sql_ok( + fs_server, + f""" + CREATE STREAMING TABLE {stream_name} WITH ( + 'connector' = 'kafka', + 'topic' = '{out_topic}', + 'format' = 'json', + 'bootstrap.servers' = '{kafka_topics}' + ) AS + SELECT + TUMBLE(INTERVAL '2' SECOND) AS time_window, + campaign_id, + COUNT(*) AS total_impressions + FROM {source_name} + GROUP BY 1, campaign_id; + """, + ) + + now = dt.datetime.now(dt.timezone.utc) + base = now.replace(microsecond=0) - dt.timedelta(seconds=8) + old_window_msgs: List[Dict[str, Any]] = [ + { + "impression_id": "i-1", + "campaign_id": 1001, + "impression_time": (base + dt.timedelta(milliseconds=100)).isoformat(), + }, + { + "impression_id": "i-2", + "campaign_id": 1001, + "impression_time": (base + dt.timedelta(milliseconds=500)).isoformat(), + }, + { + "impression_id": "i-3", + "campaign_id": 1002, + "impression_time": (base + dt.timedelta(milliseconds=900)).isoformat(), + }, + ] + advance_wm = { + "impression_id": "i-4", + "campaign_id": 9999, + "impression_time": dt.datetime.now(dt.timezone.utc).isoformat(), + } + + produce_messages(kafka_topics, in_topic, [json.dumps(x) for x in old_window_msgs + [advance_wm]]) + time.sleep(1.0) + + records = consume_messages(kafka_topics, out_topic, expected_count=2, timeout=15.0) + got = 
{(int(r["campaign_id"]), int(r["total_impressions"])) for r in records} + assert got == {(1001, 2), (1002, 1)} + + _sql_ok(fs_server, f"DROP STREAMING TABLE {stream_name};") + + def test_hop_window_with_where_filter( + self, + fs_server: Any, + kafka: Any, + kafka_topics: str, + ) -> None: + source_name = _uid("ad_impressions_src") + stream_name = _uid("metric_hop_uv") + in_topic = _uid("topic_in") + out_topic = _uid("topic_out") + kafka.create_topics_if_not_exist([in_topic, out_topic]) + self._create_impression_source(fs_server, source_name, in_topic, kafka_topics) + + _sql_ok( + fs_server, + f""" + CREATE STREAMING TABLE {stream_name} WITH ( + 'connector' = 'kafka', + 'topic' = '{out_topic}', + 'format' = 'json', + 'bootstrap.servers' = '{kafka_topics}' + ) AS + SELECT + HOP(INTERVAL '1' SECOND, INTERVAL '4' SECOND) AS time_window, + ad_id, + COUNT(*) AS kept_rows + FROM {source_name} + WHERE campaign_id = 2001 + GROUP BY 1, ad_id; + """, + ) + + now = dt.datetime.now(dt.timezone.utc) + base = now.replace(microsecond=0) - dt.timedelta(seconds=8) + msgs = [ + {"impression_id": "h1", "ad_id": 11, "campaign_id": 2001, "user_id": "u1", + "impression_time": (base + dt.timedelta(milliseconds=100)).isoformat()}, + {"impression_id": "h2", "ad_id": 11, "campaign_id": 2002, "user_id": "u2", + "impression_time": (base + dt.timedelta(milliseconds=300)).isoformat()}, + {"impression_id": "h3", "ad_id": 12, "campaign_id": 2001, "user_id": "u3", + "impression_time": (base + dt.timedelta(milliseconds=600)).isoformat()}, + {"impression_id": "h4", "ad_id": 999, "campaign_id": 9999, "user_id": "wm", + "impression_time": dt.datetime.now(dt.timezone.utc).isoformat()}, + ] + produce_messages(kafka_topics, in_topic, [json.dumps(x) for x in msgs]) + rows = consume_messages(kafka_topics, out_topic, expected_count=2, timeout=15.0) + got = {(int(r["ad_id"]), int(r["kept_rows"])) for r in rows} + assert got == {(11, 1), (12, 1)} + _sql_ok(fs_server, f"DROP STREAMING TABLE {stream_name};") + 
+ def test_session_window_user_activity( + self, + fs_server: Any, + kafka: Any, + kafka_topics: str, + ) -> None: + source_name = _uid("ad_impressions_src") + stream_name = _uid("metric_session_impr") + in_topic = _uid("topic_in") + out_topic = _uid("topic_out") + kafka.create_topics_if_not_exist([in_topic, out_topic]) + self._create_impression_source(fs_server, source_name, in_topic, kafka_topics) + + _sql_ok( + fs_server, + f""" + CREATE STREAMING TABLE {stream_name} WITH ( + 'connector' = 'kafka', + 'topic' = '{out_topic}', + 'format' = 'json', + 'bootstrap.servers' = '{kafka_topics}' + ) AS + SELECT + SESSION(INTERVAL '2' SECOND) AS time_window, + user_id, + COUNT(*) AS impressions_in_session + FROM {source_name} + GROUP BY 1, user_id; + """, + ) + + now = dt.datetime.now(dt.timezone.utc) + base = now.replace(microsecond=0) - dt.timedelta(seconds=10) + msgs = [ + {"impression_id": "s1", "ad_id": 1, "campaign_id": 1, "user_id": "uA", + "impression_time": (base + dt.timedelta(milliseconds=100)).isoformat()}, + {"impression_id": "s2", "ad_id": 1, "campaign_id": 1, "user_id": "uA", + "impression_time": (base + dt.timedelta(milliseconds=900)).isoformat()}, + {"impression_id": "s3", "ad_id": 2, "campaign_id": 1, "user_id": "uB", + "impression_time": (base + dt.timedelta(milliseconds=1200)).isoformat()}, + {"impression_id": "s4", "ad_id": 999, "campaign_id": 9999, "user_id": "wm", + "impression_time": dt.datetime.now(dt.timezone.utc).isoformat()}, + ] + produce_messages(kafka_topics, in_topic, [json.dumps(x) for x in msgs]) + rows = consume_messages(kafka_topics, out_topic, expected_count=2, timeout=15.0) + got = {(r["user_id"], int(r["impressions_in_session"])) for r in rows} + assert got == {("uA", 2), ("uB", 1)} + _sql_ok(fs_server, f"DROP STREAMING TABLE {stream_name};") \ No newline at end of file From 7a95c679a90eaf6d377690d6a82b3b6691c71c9e Mon Sep 17 00:00:00 2001 From: luoluoyuyu Date: Mon, 20 Apr 2026 21:47:43 +0800 Subject: [PATCH 26/26] update --- 
src/config/streaming_job.rs | 18 ++ src/runtime/streaming/execution/pipeline.rs | 6 +- .../streaming/execution/source_driver.rs | 6 +- src/runtime/streaming/job/job_manager.rs | 233 ++++++++++-------- src/runtime/streaming/job/models.rs | 4 +- src/server/initializer.rs | 2 + 6 files changed, 156 insertions(+), 113 deletions(-) diff --git a/src/config/streaming_job.rs b/src/config/streaming_job.rs index 6ea45609..0b0d1cde 100644 --- a/src/config/streaming_job.rs +++ b/src/config/streaming_job.rs @@ -15,6 +15,7 @@ use serde::{Deserialize, Serialize}; pub const DEFAULT_CHECKPOINT_INTERVAL_MS: u64 = 60 * 1000; pub const DEFAULT_PIPELINE_PARALLELISM: u32 = 1; pub const DEFAULT_KEY_BY_PARALLELISM: u32 = 1; +pub const DEFAULT_JOB_MANAGER_CONTROL_PLANE_THREADS: u32 = 1; #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct StreamingJobConfig { @@ -25,6 +26,10 @@ pub struct StreamingJobConfig { /// Physical parallelism for KeyBy / key-extraction operators in planned streaming graphs. 
#[serde(default)] pub key_by_parallelism: Option, + #[serde(default)] + pub job_manager_control_plane_threads: Option, + #[serde(default)] + pub job_manager_data_plane_threads: Option, } #[derive(Debug, Clone, Copy)] @@ -32,10 +37,15 @@ pub struct ResolvedStreamingJobConfig { pub checkpoint_interval_ms: u64, pub pipeline_parallelism: u32, pub key_by_parallelism: u32, + pub job_manager_control_plane_threads: u32, + pub job_manager_data_plane_threads: u32, } impl StreamingJobConfig { pub fn resolve(&self) -> ResolvedStreamingJobConfig { + let cpu_threads = std::thread::available_parallelism() + .map(|n| n.get() as u32) + .unwrap_or(1); ResolvedStreamingJobConfig { checkpoint_interval_ms: self .checkpoint_interval_ms @@ -49,6 +59,14 @@ impl StreamingJobConfig { .key_by_parallelism .filter(|&p| p > 0) .unwrap_or(DEFAULT_KEY_BY_PARALLELISM), + job_manager_control_plane_threads: self + .job_manager_control_plane_threads + .filter(|&p| p > 0) + .unwrap_or(DEFAULT_JOB_MANAGER_CONTROL_PLANE_THREADS), + job_manager_data_plane_threads: self + .job_manager_data_plane_threads + .filter(|&p| p > 0) + .unwrap_or(cpu_threads), } } } diff --git a/src/runtime/streaming/execution/pipeline.rs b/src/runtime/streaming/execution/pipeline.rs index a9a2d102..91309a48 100644 --- a/src/runtime/streaming/execution/pipeline.rs +++ b/src/runtime/streaming/execution/pipeline.rs @@ -10,7 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc::UnboundedReceiver; use tokio_stream::{StreamExt, StreamMap}; use tracing::{Instrument, info, info_span}; @@ -33,7 +33,7 @@ pub struct Pipeline { chain_head: Box, ctx: TaskContext, inboxes: Vec, - control_rx: Receiver, + control_rx: UnboundedReceiver, wm_tracker: WatermarkTracker, barrier_aligner: BarrierAligner, @@ -45,7 +45,7 @@ impl Pipeline { operators: Vec>, ctx: TaskContext, inboxes: Vec, - control_rx: Receiver, + control_rx: UnboundedReceiver, ) -> Result { let input_count = inboxes.len(); let chain_head = ChainBuilder::build(operators) diff --git a/src/runtime/streaming/execution/source_driver.rs b/src/runtime/streaming/execution/source_driver.rs index 9f403053..b4e7d327 100644 --- a/src/runtime/streaming/execution/source_driver.rs +++ b/src/runtime/streaming/execution/source_driver.rs @@ -10,7 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc::UnboundedReceiver; use tokio::time::{Instant, sleep}; use tracing::{Instrument, info, info_span, warn}; @@ -28,7 +28,7 @@ pub struct SourceDriver { operator: Box, chain_head: Option>, ctx: TaskContext, - control_rx: Receiver, + control_rx: UnboundedReceiver, } impl SourceDriver { @@ -36,7 +36,7 @@ impl SourceDriver { operator: Box, chain_head: Option>, ctx: TaskContext, - control_rx: Receiver, + control_rx: UnboundedReceiver, ) -> Self { Self { operator, diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index 130b76b9..a9bc546f 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -17,10 +17,11 @@ use std::sync::{Arc, Mutex, OnceLock, RwLock}; use std::time::Duration; use anyhow::{Context, Result, anyhow, bail, ensure}; -use tokio::sync::mpsc; +use tokio::sync::mpsc::{self, UnboundedSender}; use tokio::task::JoinHandle as TokioJoinHandle; +use 
tokio::time::Instant; use tokio_stream::wrappers::ReceiverStream; -use tracing::{debug, error, info, warn}; +use tracing::{error, info, warn}; use protocol::function_stream_graph::{ChainedOperator, FsProgram}; use protocol::storage::{ @@ -79,6 +80,8 @@ pub struct StateConfig { pub soft_limit_ratio: f64, pub checkpoint_interval_ms: u64, pub pipeline_parallelism: u32, + pub job_manager_control_plane_threads: u32, + pub job_manager_data_plane_threads: u32, /// Total bytes shared by all [`crate::runtime::streaming::state::OperatorStateStore`] (global pool). pub per_operator_memory_bytes: u64, } @@ -91,6 +94,10 @@ impl Default for StateConfig { soft_limit_ratio: 0.7, checkpoint_interval_ms: DEFAULT_CHECKPOINT_INTERVAL_MS, pipeline_parallelism: DEFAULT_PIPELINE_PARALLELISM, + job_manager_control_plane_threads: 2, + job_manager_data_plane_threads: std::thread::available_parallelism() + .map(|n| n.get() as u32) + .unwrap_or(1), per_operator_memory_bytes: DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, } } @@ -128,6 +135,8 @@ pub struct JobManager { io_pool: Mutex>, state_base_dir: PathBuf, state_config: StateConfig, + control_rt: Arc, + data_rt: Arc, } struct PreparedChain { @@ -142,13 +151,14 @@ enum PipelineRunner { struct CheckpointCoordinatorConfig { job_id: String, - source_control_txs: Vec>, - all_pipeline_control_txs: Vec>, + source_control_txs: Vec>, + all_pipeline_control_txs: Vec>, job_master_rx: mpsc::Receiver, expected_pipeline_ids: HashSet, interval_ms: u64, start_epoch: u64, job_state_dir: PathBuf, + timeout: Duration, } impl PipelineRunner { @@ -185,6 +195,18 @@ impl JobManager { state_base_dir: impl AsRef, state_config: StateConfig, ) -> Result { + let control_rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(state_config.job_manager_control_plane_threads.max(1) as usize) + .thread_name("fs-control-plane") + .enable_all() + .build() + .context("Failed to initialize control runtime")?; + let data_rt = tokio::runtime::Builder::new_multi_thread() + 
.worker_threads(state_config.job_manager_data_plane_threads.max(1) as usize) + .thread_name("fs-data-plane") + .enable_all() + .build() + .context("Failed to initialize data runtime")?; let metrics = Arc::new(NoopMetricsCollector); let (io_pool, io_manager_client) = IoPool::try_new( state_config.max_background_spills, @@ -200,6 +222,8 @@ impl JobManager { io_pool: Mutex::new(Some(io_pool)), state_base_dir: state_base_dir.as_ref().to_path_buf(), state_config, + control_rt: Arc::new(control_rt), + data_rt: Arc::new(data_rt), }) } @@ -299,6 +323,7 @@ impl JobManager { interval_ms, start_epoch: safe_epoch + 1, job_state_dir: job_state_dir.clone(), + timeout: Duration::from_millis(interval_ms.max(1) * 3), }); let graph = PhysicalExecutionGraph { @@ -322,7 +347,7 @@ impl JobManager { let control_senders = self.extract_control_senders(job_id)?; for tx in control_senders { - let _ = tx.send(ControlCommand::Stop { mode: mode.clone() }).await; + let _ = tx.send(ControlCommand::Stop { mode: mode.clone() }); } info!(job_id = %job_id, mode = ?mode, "Job stop signal dispatched."); @@ -474,7 +499,10 @@ impl JobManager { StreamingJobRollupStatus::Reconciling } } - fn extract_control_senders(&self, job_id: &str) -> Result>> { + fn extract_control_senders( + &self, + job_id: &str, + ) -> Result>> { let jobs_guard = self .active_jobs .read() @@ -541,7 +569,7 @@ impl JobManager { pipeline_id ); - let (control_tx, control_rx) = mpsc::channel(64); + let (control_tx, control_rx) = mpsc::unbounded_channel(); let status = Arc::new(RwLock::new(PipelineStatus::Initializing)); let subtask_index = 0; @@ -593,9 +621,7 @@ impl JobManager { ) }; - let handle = self - .spawn_worker_thread(job_id, pipeline_id, runner, Arc::clone(&status)) - .with_context(|| format!("Failed to spawn OS thread for pipeline {}", pipeline_id))?; + let handle = self.spawn_worker_task(job_id, pipeline_id, runner, Arc::clone(&status)); let pipeline = PhysicalPipeline { pipeline_id, @@ -637,55 +663,39 @@ impl JobManager { 
}) } - fn spawn_worker_thread( + fn spawn_worker_task( &self, job_id: String, pipeline_id: u32, runner: PipelineRunner, status: Arc>, - ) -> Result> { - let thread_name = format!("Task-{job_id}-{pipeline_id}"); - - let handle = std::thread::Builder::new() - .name(thread_name) - .spawn(move || { - if let Ok(mut st) = status.write() { - *st = PipelineStatus::Running; - } + ) -> TokioJoinHandle<()> { + self.data_rt.spawn(async move { + if let Ok(mut st) = status.write() { + *st = PipelineStatus::Running; + } - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("Failed to build current-thread Tokio runtime"); - - let execution_result = - std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - rt.block_on(async move { - runner - .run() - .await - .map_err(|e| anyhow!("Execution failed: {e}")) - }) - })); - - Self::handle_pipeline_exit(&job_id, pipeline_id, execution_result, &status); - })?; + let execution_result = runner + .run() + .await + .map_err(|e| anyhow!("Execution failed: {e}")); - Ok(handle) + Self::handle_pipeline_exit(&job_id, pipeline_id, execution_result, &status); + }) } fn handle_pipeline_exit( job_id: &str, pipeline_id: u32, - thread_result: std::thread::Result>, + result: Result<()>, status: &RwLock, ) { - let (final_status, is_fatal) = match thread_result { - Ok(Ok(_)) => { + let (final_status, is_fatal) = match result { + Ok(_) => { info!(job_id = %job_id, pipeline_id = pipeline_id, "Pipeline finished gracefully."); (PipelineStatus::Finished, false) } - Ok(Err(e)) => { + Err(e) => { error!(job_id = %job_id, pipeline_id = pipeline_id, error = %e, "Pipeline failed."); ( PipelineStatus::Failed { @@ -695,16 +705,6 @@ impl JobManager { true, ) } - Err(_) => { - error!(job_id = %job_id, pipeline_id = pipeline_id, "Pipeline thread panicked!"); - ( - PipelineStatus::Failed { - error: "Unexpected panic in task thread".into(), - is_panic: true, - }, - true, - ) - } }; if let Ok(mut st) = status.write() { @@ -724,17 
+724,18 @@ impl JobManager { &self, cfg: CheckpointCoordinatorConfig, ) -> TokioJoinHandle<()> { - let CheckpointCoordinatorConfig { - job_id, - source_control_txs, - all_pipeline_control_txs, - mut job_master_rx, - expected_pipeline_ids, - interval_ms, - start_epoch, - job_state_dir, - } = cfg; - tokio::spawn(async move { + self.control_rt.spawn(async move { + let CheckpointCoordinatorConfig { + job_id, + mut source_control_txs, + all_pipeline_control_txs, + mut job_master_rx, + expected_pipeline_ids, + interval_ms, + start_epoch, + job_state_dir, + timeout, + } = cfg; if interval_ms == 0 { info!(job_id = %job_id, "Checkpoint disabled for this job"); return; @@ -744,17 +745,19 @@ impl JobManager { interval.tick().await; let mut current_epoch: u64 = start_epoch; - let mut pending_checkpoints: HashMap> = HashMap::new(); - let mut source_reports: HashMap> = HashMap::new(); - - async fn broadcast_checkpoint_phase2( - txs: &[mpsc::Sender], - cmd: ControlCommand, - ) { - for tx in txs { - let _ = tx.send(cmd.clone()).await; - } + struct PendingCheckpoint { + epoch: u64, + missing_acks: HashSet, + start_time: Instant, + source_reports: Vec, } + let mut active_checkpoint: Option = None; + + let broadcast_cmd = |cmd: ControlCommand| { + for tx in &all_pipeline_control_txs { + let _ = tx.send(cmd.clone()); + } + }; loop { tokio::select! 
{ @@ -767,23 +770,23 @@ impl JobManager { epoch, source_payloads, } => { - if !source_payloads.is_empty() { - source_reports - .entry(epoch) - .or_default() - .extend(source_payloads); - } - if let Some(pending_set) = pending_checkpoints.get_mut(&epoch) { - pending_set.remove(&pipeline_id); + if let Some(pending) = &mut active_checkpoint { + if pending.epoch != epoch { + continue; + } + pending.missing_acks.remove(&pipeline_id); + if !source_payloads.is_empty() { + pending.source_reports.extend(source_payloads); + } - if pending_set.is_empty() { + if pending.missing_acks.is_empty() { info!( job_id = %job_id, epoch = epoch, "Checkpoint Epoch is GLOBALLY COMPLETED (phase 1); persisting metadata and notifying operators (phase 2)" ); - let payloads = source_reports.remove(&epoch).unwrap_or_default(); - let kf = decode_kafka_checkpoints_from_source_payloads(payloads, epoch); + let completed = active_checkpoint.take().expect("active checkpoint exists"); + let kf = decode_kafka_checkpoints_from_source_payloads(completed.source_reports, epoch); let epoch_u32 = u32::try_from(epoch).unwrap_or(u32::MAX); let mut catalog_ok = true; @@ -813,33 +816,50 @@ impl JobManager { } else { ControlCommand::AbortCheckpoint { epoch: epoch_u32 } }; - broadcast_checkpoint_phase2(&all_pipeline_control_txs, phase2).await; - - pending_checkpoints.remove(&epoch); + broadcast_cmd(phase2); } } } JobMasterEvent::CheckpointDecline { pipeline_id, epoch, reason } => { - error!( - job_id = %job_id, epoch = epoch, pipeline_id = pipeline_id, - reason = %reason, "Checkpoint FAILED!" 
- ); - if pending_checkpoints.remove(&epoch).is_some() { - source_reports.remove(&epoch); - let epoch_u32 = u32::try_from(epoch).unwrap_or(u32::MAX); - broadcast_checkpoint_phase2( - &all_pipeline_control_txs, - ControlCommand::AbortCheckpoint { epoch: epoch_u32 }, - ) - .await; + if let Some(pending) = &active_checkpoint + && pending.epoch == epoch + { + error!( + job_id = %job_id, epoch = epoch, pipeline_id = pipeline_id, + reason = %reason, "Checkpoint FAILED!" + ); + broadcast_cmd(ControlCommand::AbortCheckpoint { + epoch: u32::try_from(epoch).unwrap_or(u32::MAX), + }); + active_checkpoint = None; } } } } - _ = interval.tick(), if pending_checkpoints.is_empty() => { + _ = interval.tick() => { + if let Some(pending) = &active_checkpoint { + if pending.start_time.elapsed() > timeout { + warn!( + job_id = %job_id, + epoch = pending.epoch, + "Checkpoint timed out; aborting active epoch" + ); + broadcast_cmd(ControlCommand::AbortCheckpoint { + epoch: u32::try_from(pending.epoch).unwrap_or(u32::MAX), + }); + } else { + continue; + } + } + + source_control_txs.retain(|tx| !tx.is_closed()); + if source_control_txs.is_empty() { + info!(job_id = %job_id, "All source pipelines closed; checkpoint coordinator exiting"); + break; + } + info!(job_id = %job_id, epoch = current_epoch, "Triggering global Checkpoint Barrier."); - pending_checkpoints.insert(current_epoch, expected_pipeline_ids.clone()); let barrier = CheckpointBarrier { epoch: current_epoch as u32, @@ -847,13 +867,16 @@ impl JobManager { timestamp: std::time::SystemTime::now(), then_stop: false, }; + active_checkpoint = Some(PendingCheckpoint { + epoch: current_epoch, + missing_acks: expected_pipeline_ids.clone(), + start_time: Instant::now(), + source_reports: Vec::new(), + }); for tx in &source_control_txs { let cmd = ControlCommand::trigger_checkpoint(barrier); - if tx.send(cmd).await.is_err() { - debug!(job_id = %job_id, "Source disconnected. 
Shutting down coordinator."); - return; - } + let _ = tx.send(cmd); } current_epoch += 1; } diff --git a/src/runtime/streaming/job/models.rs b/src/runtime/streaming/job/models.rs index f4e2f280..e81649f2 100644 --- a/src/runtime/streaming/job/models.rs +++ b/src/runtime/streaming/job/models.rs @@ -13,11 +13,11 @@ use std::collections::HashMap; use std::fmt; use std::sync::{Arc, RwLock}; -use std::thread::JoinHandle; use std::time::Instant; use protocol::function_stream_graph::FsProgram; use tokio::sync::mpsc; +use tokio::task::JoinHandle; use crate::runtime::streaming::protocol::control::ControlCommand; @@ -78,7 +78,7 @@ pub struct PhysicalPipeline { pub pipeline_id: u32, pub handle: Option>, pub status: Arc>, - pub control_tx: mpsc::Sender, + pub control_tx: mpsc::UnboundedSender, } pub struct PhysicalExecutionGraph { diff --git a/src/server/initializer.rs b/src/server/initializer.rs index a20f2ffb..8a04608e 100644 --- a/src/server/initializer.rs +++ b/src/server/initializer.rs @@ -186,6 +186,8 @@ fn initialize_job_manager(config: &GlobalConfig) -> Result<()> { let state_config = StateConfig { checkpoint_interval_ms: job.checkpoint_interval_ms, pipeline_parallelism: job.pipeline_parallelism, + job_manager_control_plane_threads: job.job_manager_control_plane_threads, + job_manager_data_plane_threads: job.job_manager_data_plane_threads, per_operator_memory_bytes, ..StateConfig::default() };