Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 206 additions & 76 deletions gigl-core/core/sampling/ppr_forward_push.cpp

Large diffs are not rendered by default.

15 changes: 13 additions & 2 deletions gigl-core/core/sampling/ppr_forward_push.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,12 @@ class PPRForwardPush {
std::optional<std::unordered_map<int32_t, torch::Tensor>> drainQueue();

// Push residuals given fetched neighbor data.
// fetchedByEtypeId: {etype_id: (node_ids[N], flat_nbrs[sum(counts)], counts[N])}
// fetchedByEtypeId: {etype_id: (node_ids[N], flat_nbrs[sum(counts)], counts[N], flat_weights[sum(counts)])}
// flat_weights is empty (numel()==0) for uniform-residual mode; non-empty for
// weight-proportional mode. _hasWeights is latched true on the first call with a
// non-empty flat_weights and never reset within one PPRForwardPush lifetime.
void pushResiduals(const std::unordered_map<
int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>&
int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>&
fetchedByEtypeId);

// Return top-k PPR nodes per seed per node type.
Expand Down Expand Up @@ -103,6 +106,14 @@ class PPRForwardPush {
// impractical (contrast with _state above). Populated incrementally; avoids re-fetching.
std::unordered_map<uint64_t, std::vector<int32_t>> _neighborCache;

// True once any pushResiduals call receives a non-empty flat_weights tensor.
// Latched true for the object lifetime; never reset.
bool _hasWeights{false};

// Per-edge weights parallel to _neighborCache: _weightCache[packKey(node, etype)][i]
// is the weight of the i-th cached neighbor. Only populated in weighted mode.
std::unordered_map<uint64_t, std::vector<double>> _weightCache;

};

} // namespace gigl
10 changes: 8 additions & 2 deletions gigl-core/core/sampling/python_ppr_forward_push.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,21 @@ namespace gigl {
// pushResiduals: a wrapper is needed solely to release the GIL during the C++ push.
// pybind11/stl.h handles all type conversions automatically; the other methods use
// direct member function pointers for the same reason.
//
// Each tuple value is (node_ids, flat_nbrs, counts, flat_weights). flat_weights is
// an empty tensor in uniform-residual mode and a non-empty float64 tensor in
// weight-proportional mode.
static void pushResidualsWrapper(PPRForwardPush& state, const py::dict& fetchedByEtypeId) {
std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> neighborTensorsByEtypeId;
std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
neighborTensorsByEtypeId;
// Dict iteration touches Python objects — GIL must be held here.
for (auto item : fetchedByEtypeId) {
auto edgeTypeId = item.first.cast<int32_t>();
auto neighborTensors = item.second.cast<py::tuple>();
neighborTensorsByEtypeId[edgeTypeId] = {neighborTensors[0].cast<torch::Tensor>(),
neighborTensors[1].cast<torch::Tensor>(),
neighborTensors[2].cast<torch::Tensor>()};
neighborTensors[2].cast<torch::Tensor>(),
neighborTensors[3].cast<torch::Tensor>()};
}
// C++ push only uses tensor accessor/data_ptr APIs — GIL-safe to release.
// Releasing here lets the asyncio event loop process RPC completion callbacks
Expand Down
4 changes: 3 additions & 1 deletion gigl-core/src/gigl_core/ppr_forward_push.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ class PPRForwardPush:
def drain_queue(self) -> dict[int, torch.Tensor] | None: ...
def push_residuals(
self,
fetched_by_etype_id: dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
fetched_by_etype_id: dict[
int, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
],
) -> None: ...
def extract_top_k(
self, max_ppr_nodes: int
Expand Down
48 changes: 44 additions & 4 deletions gigl-core/tests/ppr_forward_push_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

using gigl::PPRForwardPush;

using FetchedByEtypeId = std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>;

// Builds a single-edge-type, single-node-type PPRForwardPush.
static PPRForwardPush makeState(
const std::vector<int64_t>& seeds,
Expand All @@ -20,16 +22,21 @@ static PPRForwardPush makeState(
}

// Convenience wrapper: build the fetchedByEtypeId argument for pushResiduals
// from flat vectors, keeping test call sites readable.
static std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
// from flat vectors, keeping test call sites readable. Empty weights select
// uniform-residual mode.
static FetchedByEtypeId
makeFetched(int32_t edgeTypeId,
const std::vector<int64_t>& nodeIds,
const std::vector<int64_t>& flatNeighborIds,
const std::vector<int64_t>& counts) {
const std::vector<int64_t>& counts,
const std::vector<double>& flatWeights = {}) {
auto weightsTensor =
flatWeights.empty() ? torch::empty({0}, torch::kDouble) : torch::tensor(flatWeights, torch::kDouble);
return {{edgeTypeId,
{torch::tensor(nodeIds, torch::kLong),
torch::tensor(flatNeighborIds, torch::kLong),
torch::tensor(counts, torch::kLong)}}};
torch::tensor(counts, torch::kLong),
weightsTensor}}};
}

// After construction, drainQueue() returns the seed node under etype 0.
Expand Down Expand Up @@ -89,6 +96,39 @@ TEST(PPRForwardPush, ResidualDistributedToNeighbor) {
EXPECT_NEAR(weights[1].item<float>(), static_cast<float>((1.0 - alpha) * alpha), 1e-5F);
}

// Weighted mode: an edge whose weight is zero must not cause its endpoint to be
// queued for the next round, nor to appear in the extracted top-k. Here seed 0
// has neighbors 1 (weight 0.0) and 2 (weight 2.0); only node 2 may be reached.
TEST(PPRForwardPush, WeightedResidualSkipsZeroWeightNeighbor) {
    const double alpha = 0.5;
    auto state = makeState(/*seeds=*/{0}, alpha, /*requeueThresholdFactor=*/1e-6, /*degrees=*/{2, 0, 0});

    // Round 1: pop the seed, then push along its two weighted edges.
    state.drainQueue();
    const auto weightedFetch = makeFetched(
        /*edgeTypeId=*/0,
        /*nodeIds=*/{0},
        /*flatNeighborIds=*/{1, 2},
        /*counts=*/{2},
        /*flatWeights=*/{0.0, 2.0});
    state.pushResiduals(weightedFetch);

    // Round 2: exactly one node is queued under etype 0, and it is node 2.
    const auto secondRound = state.drainQueue();
    ASSERT_TRUE(secondRound.has_value());
    const auto& queuedByEtype = *secondRound;
    const auto queuedEntry = queuedByEtype.find(0);
    ASSERT_NE(queuedEntry, queuedByEtype.end());
    const auto& queuedNodes = queuedEntry->second;
    ASSERT_EQ(queuedNodes.size(0), 1);
    EXPECT_EQ(queuedNodes[0].item<int64_t>(), 2);

    // Pushing with no fetched neighbor data must leave the queue empty.
    state.pushResiduals({});
    const auto thirdRound = state.drainQueue();
    EXPECT_FALSE(thirdRound.has_value());

    // Top-k for node type 0 holds two entries: the seed, then node 2.
    const auto topKByNodeType = state.extractTopK(10);
    const auto topKEntry = topKByNodeType.find(0);
    ASSERT_NE(topKEntry, topKByNodeType.end());
    const auto& [topIds, topScores, perSeedCounts] = topKEntry->second;
    ASSERT_EQ(perSeedCounts[0].item<int64_t>(), 2);
    EXPECT_EQ(topIds[0].item<int64_t>(), 0);
    EXPECT_EQ(topIds[1].item<int64_t>(), 2);

    // Expected scores: alpha for the seed, (1 - alpha) * alpha for node 2.
    const auto expectedSeedScore = static_cast<float>(alpha);
    const auto expectedNeighborScore = static_cast<float>((1.0 - alpha) * alpha);
    EXPECT_NEAR(topScores[0].item<float>(), expectedSeedScore, 1e-5F);
    EXPECT_NEAR(topScores[1].item<float>(), expectedNeighborScore, 1e-5F);
}

// Two seeds (0 and 1) both push residual to sink node 2. The neighbor-lookup
// request must deduplicate to one entry for node 2, yet both seeds must still
// accumulate a PPR score for it.
Expand Down
7 changes: 0 additions & 7 deletions gigl/distributed/base_dist_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,6 @@ def validate_for_weighted_sampling(

Raises:
ValueError: If ``with_weight=True`` but no edge weights are registered.
NotImplementedError: If ``with_weight=True`` and a PPR sampler is requested.
"""
if not with_weight:
return
Expand All @@ -362,12 +361,6 @@ def validate_for_weighted_sampling(
"with_weight=True requires edge weights to be registered in the dataset. "
"Pass weight_edge_feat_name to build_dataset() to register edge weights."
)
# TODO(mkolodner-sc): Implement weight-proportional residual propagation for PPR.
if with_weight and isinstance(sampler_options, PPRSamplerOptions):
raise NotImplementedError(
"Weighted sampling is not yet supported with PPRSamplerOptions. "
"Weight-proportional residual propagation for PPR is planned but not implemented."
)

@staticmethod
def create_sampling_config(
Expand Down
Loading