FIR phase-shift + extract apply_raised_cosine_taper from get_chunk_with_margin#4563
FIR phase-shift + extract apply_raised_cosine_taper from get_chunk_with_margin#4563galenlynch wants to merge 5 commits into
apply_raised_cosine_taper from get_chunk_with_margin#4563Conversation
|
Test failures are unrelated to this PR |
|
@samuelgarcia @oliche can you take a look at this implementation? |
Adds a sinc-FIR alternative to the FFT-based PhaseShift path, and factors
the FFT-specific raised-cosine taper out of get_chunk_with_margin so
bounded-support FIR consumers don't have to pay for it.
Changes
-------
- PhaseShiftRecording(method="fft"|"fir", n_taps=32, output_dtype=None)
- method="fir": 32-tap Kaiser-windowed sinc, numba-jit with prange over time
- Per-channel kernels cached once per segment
- Int16-native fast path: reads int16 directly, writes float32 when
output_dtype=np.float32 (skips round-back-to-int16)
- FIR margin = n_taps // 2 (16 for 32-tap default), vs FFT path's 40 ms
- Default "fft" preserves existing behavior
- apply_raised_cosine_taper(data, margin, *, inplace=True) public function
in spikeinterface.core.time_series_tools
- get_chunk_with_margin(window_on_margin=True) emits DeprecationWarning
and delegates to apply_raised_cosine_taper; old behavior preserved
Whittaker-Shannon justification
-------------------------------
Both FFT and FIR paths are numerical realisations of the same ideal
sinc-interpolation implied by the sampling theorem. FFT does it spectrally
(phase ramp = DFT of full sinc kernel); FIR does it in time against a
Kaiser-windowed 32-tap truncation of the same sinc. Approximations: sinc
truncation (32 taps captures >99% of kernel energy for any d ∈ [0,1)) and
Kaiser windowing (-80 dB stopband, two orders of magnitude below NP's
~50 dB SNR). Measured 0.19% spike-band RMS vs FFT on real NP 2.0 data.
Performance (1M × 384 int16 AIND pipeline, PS → HP → CMR, 24-core host)
-----------------------------------------------------------------------
stock FFT FIR speedup
single-call get_traces 82.1 s 6.28 s 13.1×
same, f32 propagated 85.8 s 4.40 s 19.5×
TimeSeriesChunkExecutor n=24 4.98 s 1.62 s 3.1× on top
of CRE n=24
Memory (same pipeline, n_jobs=24, chunk=1s):
FFT peak RSS: 10.8 GB FIR peak RSS: 3.9 GB (2.8× less)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
for more information, see https://pre-commit.ci
b3f3784 to
301355a
Compare
|
LGTM 0.2% is not a great relative tol, but this is fine for spike sorting, and this implementation is very valuable for real-time applications ! But I wonder if for waveform extraction we shouldn't push a bit further the analysis. The mismatch with the true FFT implementation will be frequency dependent and worst around Nyquist (which I think is why on a synthetic it is higher than on real data). Here we would reassure neuroscientists by showing waveforms with both techniques applied to see if there is any distortion that could be picked up by downstream models ! Then we could confidently point the users to this method for both spike sorting and waveform extraction. |
The windowed-sinc kernel has a shift-dependent DC magnitude (Σh[k] = 0.9975…1.0000 for Kaiser β=8.6 on NPX inter-sample shifts), which produced a constant per-channel amplitude bias of up to ~0.3 % structured by ADC group. Dividing each channel's kernel by its own sum forces H(0)=1 exactly and, because the passband is essentially flat at the DC value before normalization, flattens the entire passband response. Worst-case RMS error vs the FFT reference drops from 0.42 % to 0.0007 % (600×) on realistic NPX 2.0 inter-sample shifts and a 1/√f signal spectrum. Per-channel amplitude bias drops from 0.30 % to 0.05 %. Phase and group delay are unaffected. Zero runtime cost — kernels are built once per segment. Standard practice for fractional-delay FIR design. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
With per-channel DC normalization the dominant error source is gone, and the remaining in-band ripple shrinks fast with tap count. At K=16 worst-case in-band RMS vs FFT is ≤0.02 % on a white test signal cut at the NPX analog passband (10 kHz) — ~180× below the ~5 µV Johnson thermal noise of a 150 kΩ NPX 2.0 electrode at brain temperature, i.e. below the physics-imposed noise floor of the recording chain. Result: 2× faster than the previous K=32 (~170× faster than the FFT path), with no measurable accuracy regression on any signal the analog frontend can produce. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…eoff The K-tap windowed sinc cannot be perfectly allpass: it has a flat passband up to ~0.33·fs at K=16, with a soft rolloff above. For NPX this lands exactly at the 10 kHz analog cutoff, so the FIR's rolloff only attenuates out-of-band noise the analog filter already removed (arguably a benefit). For higher-fs or broadband signals where content extends past ~0.33·fs, users should raise n_taps or use method="fft". Also explain that K=16 is matched to NPX 2.0's 12-bit ADC: the FIR's ~5000:1 effective precision sits right at the ADC's precision limit. Higher K buys more accuracy than the input data provides. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
oliche
left a comment
There was a problem hiding this comment.
I don't think we are trying to optimize for sidelobes nor frequency response here, we just want apodization to only modify the edges.
Ideally those are also complementary (ie we could reconstruct the signal by summing batch n-1 with batch n at the edges).
For this the hamming window is not a good choice as it has a non-zero value at its extrema.
I'd suggest a much simpler cosine window, as done in the original IBL implementation !
c939ca5 to
7cf5748
Compare
|
Thanks for taking a look @oliche! I don't understand your last comment, are you referring to the apodization in the FFT phase shift implementation (which I think is what IBL did)? If it was in response to the window commit, that was a dead-end investigation on my part on the FIR sinc function window — I was wondering why the 0.2% error was so high, and it turned out that it was because I was not normalizing the windowed sinc functions for different phase shifts. Whether or not 0.2% tolerance matters for biology is an interesting question and caused me to do some investigation and catch the normalization bug and reduce it to 0.0145% while reducing the number of FIR taps to 16 FIR taps. I think this is sufficient for Neuropixels 1.0 and 2.0. The reason is, briefly, because of the limited bandwidth of Neuropixels (0.3 Hz - 10 kHz for NPX 2.0), their input-referred noise, and their limited ADC precision. To recap the overall approach of the PR: the existing FFT approach to interpolation is equivalent to circular convolution of the signal with a periodic sinc function. This PR speeds that up by convolving with truncated sinc functions, but we first need to window them to prevent Gibbs ripples in the frequency domain and therefore distortion of the interpolated signal. The window function here does matter since Neuropixels have a pass band of 0.3 Hz - 10 kHz: we can use very small sinc functions if our window function does not attenuate the pass band. Hann and cosine windows have 0 at the endpoints and effectively waste two FIR taps as a result. To summarize my investigations: the FIR phase shift's in-band fidelity is ≥12-bit-equivalent (better than the NPX 2.0 ADC itself), its rolloff sits exactly where the analog anti-aliasing filter already rolled off, and its impact on real spike templates is ~200× below the Neuropixel's input-referred noise floor. This makes it safe for waveform extraction and arguably a good choice for everyone, not just real-time pipelines. The acquisition hardware has a limited passband of useful signalThe NPX 2.0 has a fs=30 kHz and pass band up to 10 kHz (after which the anti-aliasing filter attenuates electrode signal), so we need a window that has a flat pass band to ~0.33·fs and rolls off above. The FIR's rolloff above 10 kHz only attenuates content the analog filter already attenuated. The K=16 Kaiser-windowed sinc achieves just that with the smallest number of taps. We can compare the passband error to discretization noise for various bit depths, where NPX 2.0 has 12 bits, and see that with this Kaiser window K=12 → 7.5 bits, K=16 → 12.8 bits, K=32 only adds ~4 more (past the ADC's own quantization floor).
Filter fidelity in-band (analytical)Computed from the kernel coefficients via
Worst-case channel in the passband is 205 ppm = 12.3 bits, which is above the 244 ppm = 12-bit threshold of the NPX 2.0 ADC itself. What the FIR actually does to the recordingThe FIR's perturbation is multiplicative at each frequency: whatever sits at frequency f in the recording — spike content, LFP residual, thermal noise, amplifier flicker — gets multiplied by 1 ± ε(f), where ε(f) is exactly the analytical error spectrum |H_FIR(f) − e^{−iωd}| from above. So the FIR's data impact is independent of the recording's spectrum; only the filter's own error matters, and that's known a priori from the kernel. Per-frequency bounds (NPX-averaged, K=16):
So whatever the spectrum of any individual recording, the FIR preserves each frequency in the spike-content band to within ~100 ppm (0.01 %), and to within ~200 ppm (0.02 %) RMS across the full passband (NPX-averaged 145 ppm = 0.0145 %, worst-shift 205 ppm = 0.0205 %). For reference: 12-bit ADC quantization is 244 ppm (0.0244 %) — the FIR perturbs each frequency by less than the ADC's own discretization noise, over the entire band where signal lives. The end-to-end validation below — applied to a real recording's actual spectrum — is the empirical confirmation that this per-frequency picture propagates correctly through the rest of the preprocessing chain. End-to-end on real dataBoth pipelines (
Mean templates from the two pipelines are visually indistinguishable on every unit I plotted; the 50× amplified residual trace is just noise. Performance of 16 tap FIR filterMoving from 32 (main PR text) to 16 taps (this message) doesn't change real-world performance on my machine since it's memory bandwidth limited (dual-channel DDR5) but high performance machines with more memory bandwidth would benefit. Just for fun, here's a roofline plot of the different FIR taps vs the existing FFT phase shift algorithm:
And the throughput on my machine for the different algorithms themsleves (in reality
|







This PR introduces a faster time-domain approach to phase shifting neuropixels recording, which can be opted-into, along side the exist FFT-approach. The time-domain approach achieves >99% accuracy in a fraction of the time.
Attempting to view my data as Kilosort4 would see it using direct calls to
get_tracesoutside ofTimeSeriesChunkExecutorwas very slow. Profiling showed that the vast majority of the time (98%) was being spent phase shifting, at least on a 384x1M chunk of data. This is perhaps an unrealistic length of data, but even with more standard ~1 second chunks of data (384x30k), the vast majority of the preprocessing time was spent phase shifting. This led me to investigate why it was so slow.The current approach to phase shifting is doing a rFFT of the entire recording chunk, multiplying the fourier-transformed data with a complex exponential, and then converting back into the time domain with irFFT. This is correct, but scales with O(n log n), requires padding to avoid wrap-around, and is very expensive per sample. However, since the signal is band-limited, we can instead use whittaker-shannon interpolation and do smaller convolutions in the time-domain.
A fractional-sample delay is a sinc interpolation at the desired offset. Whittaker–Shannon states that any signal bandlimited to Nyquist is exactly reconstructed from its samples via
so delaying channel$c$ by a fractional $d_c$ samples is
— convolution with the ideal fractional-delay kernel$h_{d_c}[k] = \mathrm{sinc}(k - d_c)$ .
The existing FFT path realises this convolution spectrally: multiplying by$e^{i,2\pi f, d_c}$ in the frequency domain is the DFT of that same infinite sinc kernel.
We can instead do time domain convolution with finite impulse response sinc filter. Truncating the sinc kernel allows massive performance gains while sacraficing minimal amounts of accuracy. In this PR, the FIR path realises the phase shift in the time domain as explicit linear convolution against a Kaiser-windowed, 32-tap truncation of the same sinc. The operation is identical; the only approximations are:
For the 384x1M sample case, changing the interpolation to a FIR filter instead of multiplying the FFT sped up my entire
PhaseShiftRecording → HighpassFilterRecording(300 Hz) → CommonReferenceRecordingpreprocessing pipeline by 5.4x. When combined with the companion PR #4564, this increased to 13.1x speedup. Notably, FIR interpolation uses ~2.8x less peak memory than the FFT approach. However, these benchmarks take advantage of numba-level parallelism that I added to the FIR approach.Isolating only the PhaseShiftRecording component shows that the algorithmic improvement alone produces 10x faster phase shifting. Here I am testing different configurations of
TimeSeriesChunkExecutorwith different levels of 'outer' parallelism (n_jobs) and 'inner' parallelism (numba threads). CRE number is then_jobssetting.Algorithm alone beats best-outer-parallelism-alone (10× vs 4.7×). The algorithmic change breaks through a ceiling that
TimeSeriesChunkExecutoron stock FFT can't — outer parallelism can only distribute the same FFT work across workers, not change the total work done. Algorithm + outer alone already reaches 48.7×; adding inner (numba default vs 1 thread) takes it from 48.7× to 54.1× — only ~10% more. Inner parallelism has diminishing returns once outer saturates cores.Correctness
But how accurate is this truncated 32-tap sinc interpolation? The tests in this PR test exactly that, and on both synthetic and actual neuropixels data the difference between the truncated sinc and the infinite sinc is ~0.2%
< 1%on synthetic< 0.5%on real NP 2.0 datatest_phase_shift(FFT chunked-vs-full identity)error_mean / rms < 0.001All existing PhaseShift tests pass unchanged.
Changes
1.
PhaseShiftRecording(method="fft"|"fir", n_taps=32, output_dtype=None)File:
src/spikeinterface/preprocessing/phase_shift.pymethodkwarg (default"fft"for backward compatibility).method="fir"uses a 32-tap Kaiser-windowed sinc FIR, implemented as numba-jit kernels withprangeover time.n_taps // 2samples (16 for the 32-tap default), not the 40 ms the FFT path needs.n_tapsconfigurable (default 32, validated as even and ≥ 2).int16-native input reader (always on for int16 parents)
When the parent recording's dtype is int16, the FIR path dispatches to an int16-input numba kernel that reads int16 samples directly and accumulates in float32. No explicit input cast — the promotion happens per-element inside the kernel's convolution loop, avoiding a full int16 → float32 buffer materialisation. Active automatically for any int16 parent; no opt-in required.
output_dtype=np.float32— skip the output round-backIndependently, the phase shift output stage can optionally skip its round-to-int16 cast:
output_dtype=None): FIR internally produces float32 samples, then rounds + casts back to the parent dtype (e.g., int16). Preserves the int16 contract for downstream stages that expect it.output_dtype=np.float32: FIR writes float32 directly, andPhaseShiftRecordingadvertises float32 as its output dtype. Downstream stages that inherit (dtype=None) consume float32 and skip their own int16 round-backs; the full pipeline stays in floating-point.dtype=np.int16, they will cast back to int16 regardless of what phase shift advertises, reinstating the round-back at their own output boundary.output_dtype=np.float32is fully effective only when the caller builds a downstream chain that inherits dtype from phase shift (or explicitly setsdtype=np.float32on HP/CMR etc.).2.
apply_raised_cosine_taperextractFile:
src/spikeinterface/core/time_series_tools.pyapply_raised_cosine_taper(data, margin, *, inplace=True)exposes the raised-cosine window that was previously inlined inget_chunk_with_margin(window_on_margin=True).window_on_margin=Truecontinues to work but is deprecated: it emits aDeprecationWarningand delegates toapply_raised_cosine_taper.PhaseShiftRecordingpath is updated to callget_chunk_with_margin(window_on_margin=False)and thenapply_raised_cosine_taperexplicitly. Output is bit-for-bit equivalent to pre-refactor behavior (regression test added).get_chunk_with_marginwas unusable for bounded-support filters both because the taper was redundant and because the in-place*=against a float taper fails on int-typed chunks. Separating the concern makes the utility method-agnostic.Performance (reproducible)
Here are some other relevant benchmarks for this PR.
benchmarks/preprocessing/bench_perf.py— synthetic NumpyRecording, 1M × 384 int16,PS → HP @ 300 Hz → CMRpipeline , measured on a 24-core x86_64 host (SI 0.103 dev, numpy 2.1, scipy 1.14, numba 0.60).End-to-end pipeline — direct
get_traces()(no CRE)Scope: full
PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecordingpipeline, singleget_traces()call on the whole recording (noTimeSeriesChunkExecutorchunking). Numba threads at default (all cores).This PR alone (FIR phase shift + stock BP/CMR):
Combined with companion
n_workersPR (measured on development branch with both PRs applied):This PR alone delivers the bulk of the direct-
get_traces()win because stock phase shift FFT dominated the pipeline; FIR demotes it from ~68 s to ~1 s, exposing band pass (~8 s) as the new bottleneck, which the companion PR'sn_workerskwargs then unlock. The int16-preserved path pays ~1.5× more time than f32-propagated because each stage round-trips through float internally and casts back.FIR × CRE outer parallelism
Scope: full
PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecordingpipeline end-to-end, 1M × 384 int16, chunk_duration="1s", stock BP/CMR (no companion-PRn_workerskwargs). Only CRE outer n_jobs and phase shiftmethodvary; numba threads are left at their default (all cores), which is the value FIR uses for its internalprange."int16 preserved" rows use the int16-throughout pipeline (each stage rounds back to int16 at its output boundary); the int16-reading FIR kernel is active internally when input is int16, but the final cast reinstates the int16 contract. "f32 throughout" rows flip every stage's
dtypetonp.float32so intermediate buffers stay in floating-point.FIR stacks cleanly with
TimeSeriesChunkExecutor. Users already running atn_jobs ≈ core_counton stock get ~3× more speedup from the algorithmic change alone (4.98s → 1.62s).Peak RSS scaling (chunk=1s, 1M × 384 int16, thread engine)
Scope: full
PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecordingpipeline end-to-end. Numba threads follow the Compatibility section's recommendation (numba default on low-n_jobsconfigs where cores are free;NUMBA_NUM_THREADS=1whenn_jobs ≈ core_countto avoid oversubscription):At
chunk=10s, n_jobs=24: stock 13.75 GB, FIR 7.13 GB (1.93× less). The gap widens withn_jobsbecause FIR's numba thread pool is allocated once process-wide; stock's scipy FFT scratch buffers are per-call and don't share across worker threads.get_traces() with entire filter chain
Scope: full
PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecordingpipeline end-to-end, 1M × 384 int16This PR alone (FIR phase-shift, stock BP/CMR)
get_traces(), int16 preservedget_traces(), f32 propagatedTimeSeriesChunkExecutor(n_jobs=24, thread), int16 preservedCombined with companion PR (adds
n_workerson BP and CMR)get_traces(), int16 preservedget_traces(), f32 propagatedTimeSeriesChunkExecutor(n_jobs=24, thread), int16 preservedCombined numbers require both PRs merged. FIR also uses ~2.8× less peak RSS than FFT at the same parallelism (at
n_jobs=24, chunk=1s: 10.8 GB → 3.9 GB) because the numba thread pool is shared across workers while scipy's FFT scratch buffers aren't.Compatibility
method="fft"is the default; existing callers get existing behavior bit-for-bit.get_chunk_with_margin(window_on_margin=True)still works and emits aDeprecationWarningpointing callers atapply_raised_cosine_taper._kwargsdict updated for new phase shift kwargs;save()/load()round-trip correctly.numbais already a soft dep of SI's Kilosort path; the FIR kernels import it lazily and raise a clear error with install instructions if missing.n_workerskwarg on phase shift. FIR parallelism is internal to the numba kernel (prangeover time), dispatched to numba's process-global thread pool. Tune viaNUMBA_NUM_THREADSenv var ornumba.set_num_threads()— the standard numba mechanism. Suggested settings: numba default (all cores) for directget_traces()callers and forn_jobs=1under CRE;NUMBA_NUM_THREADS=1whenn_jobs ≈ core_countto avoid oversubscription. Not set at library level (follows scipy/sklearn convention).Companion PR
An independent companion PR #4564 adds
n_workerskwargs onFilterRecordingandCommonReferenceRecordingwith per-caller-thread inner pools. Most valuable for directget_traces()callers, where the BP/CMR parallelism compounds on top of this PR's FIR to reach 13× (int16) / 19.5× (f32) pipeline speedup. Under CRE the companion kwargs add a smaller additional gain without causing shared-pool queueing. The two PRs have no code dependency and can land in either order.