render_points materialises dask DataFrame twice when color= is a points column #633

@timtreis

Environment: spatialdata-plot 0.3.4.dev (main, commit 5cfedc7), Python 3.13


Problem

When render_points is called with color=<col> where the column is stored directly in the points element (not in an AnnData table), the dask DataFrame is materialised twice:

  1. render.py:747: points = points[coords].compute() — loads the data into a pandas DataFrame
  2. utils.py:1091 (inside _set_color_source_vec, via get_values): a second .compute() call on the original dask DataFrame, triggered by get_values(value_key=col_for_color, sdata=sdata, element_name=element, ...)

The root cause: the preloaded_color_data optimisation (introduced to avoid redundant get_values calls) is only activated when color data came from an AnnData table (added_color_from_table = True). When the color column is native to the points element, added_color_from_table stays False, _preloaded is set to None, and _set_color_source_vec falls back to a fresh get_values() call.

For large point clouds (Xenium: ~100M points, MERFISH: ~10M points), each .compute() reads the Parquet data from disk. The redundant second materialisation therefore roughly doubles I/O and can noticeably increase peak memory usage.


Minimal reproducible example

import dask; dask.config.set({"dataframe.query-planning": False})
import numpy as np, pandas as pd
import dask.dataframe as dd
import spatialdata as sd
from spatialdata.models import PointsModel
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt
import spatialdata_plot

# Patch .compute() to count calls
compute_calls = []
original_compute = dd.DataFrame.compute
def counting_compute(self, **kwargs):
    import traceback
    compute_calls.append(traceback.extract_stack())
    return original_compute(self, **kwargs)
dd.DataFrame.compute = counting_compute

rng = np.random.default_rng(42)
n = 500
df = pd.DataFrame({
    "x": rng.uniform(0, 100, n),
    "y": rng.uniform(0, 100, n),
    "cell_type": pd.Categorical(rng.choice(["A", "B", "C"], n)),
})
points = PointsModel.parse(df)
sdata = sd.SpatialData(points={"pts": points})

fig, ax = plt.subplots()
sdata.pl.render_points("pts", color="cell_type").pl.show(ax=ax, show=False)
plt.close()

print(f"DataFrame.compute() calls: {len(compute_calls)}")
# Prints: 2  (should be 1)

Expected behaviour

DataFrame.compute() called once.
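
A regression test could assert this expectation directly. The helper below is only a minimal sketch and is not part of spatialdata-plot: count_method_calls is a hypothetical name, and FakeFrame is a stand-in class used purely to keep the snippet self-contained. Unlike the global monkeypatch in the example above, it restores the original method on exit, so other tests are not affected.

```python
from contextlib import contextmanager


@contextmanager
def count_method_calls(cls, method_name):
    """Temporarily wrap a plain instance method and record every call to it.

    Hypothetical helper for illustration; assumes `method_name` refers to a
    regular instance method (not a staticmethod or classmethod).
    """
    defined_on_cls = method_name in cls.__dict__
    original = getattr(cls, method_name)
    calls = []

    def wrapper(self, *args, **kwargs):
        calls.append((args, kwargs))
        return original(self, *args, **kwargs)

    setattr(cls, method_name, wrapper)
    try:
        yield calls
    finally:
        # Put the class back exactly as it was before patching.
        if defined_on_cls:
            setattr(cls, method_name, original)
        else:
            delattr(cls, method_name)


class FakeFrame:
    """Stand-in for a lazily evaluated frame; NOT a dask object."""

    def compute(self, **kwargs):
        return "materialised"


with count_method_calls(FakeFrame, "compute") as calls:
    FakeFrame().compute()
    FakeFrame().compute()

print(f"compute() calls: {len(calls)}")
```

In an actual test, the same context manager would presumably wrap dd.DataFrame and "compute" around the render_points(...).pl.show(...) call, followed by assert len(calls) == 1.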

Actual behaviour

DataFrame.compute() called twice:

  1. render.py:747 — explicit .compute() to load points
  2. utils.py:1091 — redundant get_values() call on the original dask frame
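
The two call sites can be read off the stacks recorded in compute_calls by the example above. The following is a small sketch of one way to do that: innermost_frame_in is a hypothetical helper, and the stack shown here is synthetic (its file paths and the script line numbers are made up for illustration) so that the snippet runs without spatialdata installed.

```python
import os
import traceback


def innermost_frame_in(stack, path_fragment):
    """Return 'file:line (function)' for the innermost frame whose file path
    contains `path_fragment`, or None if no frame matches."""
    for frame in reversed(stack):
        if path_fragment in frame.filename:
            name = os.path.basename(frame.filename)
            return f"{name}:{frame.lineno} ({frame.name})"
    return None


# Synthetic stand-in for one entry of compute_calls (outermost frame first,
# as returned by traceback.extract_stack()). Paths are illustrative only.
synthetic_stack = [
    traceback.FrameSummary("/work/repro.py", 59, "<module>", lookup_line=False),
    traceback.FrameSummary(
        "/env/spatialdata_plot/pl/render.py", 747, "_render_points", lookup_line=False
    ),
    traceback.FrameSummary("/work/repro.py", 44, "counting_compute", lookup_line=False),
]

print(innermost_frame_in(synthetic_stack, "spatialdata_plot"))
```

Applied to the real recording, something like [innermost_frame_in(s, "spatialdata_plot") for s in compute_calls] should list one entry per .compute() call.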

Fix sketch

In _render_points, extend the _preloaded logic to cover element-native color columns (not just table-sourced columns):

# Current (only covers table case):
_preloaded = (
    points_pd_with_color[col_for_color]
    if added_color_from_table and col_for_color is not None
    else None
)

# Fixed: also provide preloaded data when col was already in the loaded DataFrame:
_preloaded = (
    points_pd_with_color[col_for_color]
    if col_for_color is not None and col_for_color in points_pd_with_color.columns
    else None
)

This makes the preloaded_color_data path active for both table-sourced and element-native color columns, eliminating the redundant get_values() call.
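
To make the difference between the two conditions concrete, here is a standalone sketch of just the selection logic. It is not the library code: preloaded_current and preloaded_fixed are hypothetical function names, and a plain dict of columns stands in for the loaded pandas DataFrame (col in columns plays the role of col in df.columns).

```python
def preloaded_current(columns, col_for_color, added_color_from_table):
    """Mirror of the current condition: only table-sourced colors are reused."""
    if added_color_from_table and col_for_color is not None:
        return columns[col_for_color]
    return None


def preloaded_fixed(columns, col_for_color):
    """Mirror of the proposed condition: reuse any column already loaded."""
    if col_for_color is not None and col_for_color in columns:
        return columns[col_for_color]
    return None


# Stand-in for the already materialised points data (illustrative values).
loaded = {
    "x": [1.0, 2.0, 3.0],
    "y": [4.0, 5.0, 6.0],
    "cell_type": ["A", "B", "C"],
}

# Element-native color column: added_color_from_table is False.
native_current = preloaded_current(loaded, "cell_type", added_color_from_table=False)
native_fixed = preloaded_fixed(loaded, "cell_type")

# Table-sourced color column: both variants return the loaded values.
table_current = preloaded_current(loaded, "cell_type", added_color_from_table=True)
table_fixed = preloaded_fixed(loaded, "cell_type")

print("element-native, current:", native_current)
print("element-native, fixed:  ", native_fixed)
print("table-sourced agree:    ", table_current == table_fixed)
```

Under these assumptions, only the element-native case changes: the current condition yields None (which leads to the fallback get_values() call), while the proposed condition returns the column that is already in memory.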


Triage tier: Tier 3
