1-Nearest Neighbors

#include <raft/distance/fused_l2_nn.cuh>

namespace raft::distance

template<typename DataT, typename OutT, typename IdxT, typename ReduceOpT, typename KVPReduceOpT>
void fusedL2NN(OutT *min, const DataT *x, const DataT *y, const DataT *xn, const DataT *yn, IdxT m, IdxT n, IdxT k, void *workspace, ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, bool initOutBuffer, cudaStream_t stream)

Fused L2 distance and 1-nearest-neighbor computation in a single call.

The benefits of such a call are twofold: (1) it eliminates the need for an intermediate buffer to store the output of the GEMM, and (2) it reduces the memory read traffic on this intermediate buffer that would otherwise be needed during the 1-NN reduction phase.
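To make the semantics concrete, here is a minimal serial CPU sketch of what fusedL2NN computes in the common minimum-reduction case. It is not RAFT code: the function name l2_1nn_reference and its std::vector interface are hypothetical. It assumes row-major x (m x k) and y (n x k) plus precomputed squared row norms xn and yn, expands each squared distance as ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j, and folds every distance into a running (index, distance) minimum as soon as it is computed. That folding is the point of fusing: no m x n distance buffer is ever stored.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical CPU reference for the min-reduction result of a fused L2 1-NN.
// x is m x k, y is n x k, both row major; xn and yn are squared row norms.
// Returns, per row of x, the (index in y, distance) of its nearest neighbor.
std::vector<std::pair<int, float>> l2_1nn_reference(
    const std::vector<float>& x, const std::vector<float>& y,
    const std::vector<float>& xn, const std::vector<float>& yn,
    int m, int n, int k, bool sqrt_out)
{
  assert(x.size() == static_cast<std::size_t>(m) * k);
  assert(y.size() == static_cast<std::size_t>(n) * k);
  assert(xn.size() == static_cast<std::size_t>(m));
  assert(yn.size() == static_cast<std::size_t>(n));

  // Output initialized to an "empty" minimum, like initOutBuffer = true.
  std::vector<std::pair<int, float>> out(
      m, {-1, std::numeric_limits<float>::max()});
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      // Dot product x_i . y_j: the GEMM part of the computation.
      float dot = 0.0f;
      for (int d = 0; d < k; ++d) dot += x[i * k + d] * y[j * k + d];
      // Squared distance via the norm expansion.
      const float dist = xn[i] + yn[j] - 2.0f * dot;
      // Epilogue: reduce immediately instead of storing dist in a buffer.
      if (dist < out[i].second) out[i] = {j, dist};
    }
    // sqrt is monotonic, so it can be applied after the minimum is found;
    // clamp tiny negative values caused by floating-point cancellation.
    if (sqrt_out) out[i].second = std::sqrt(std::fmax(out[i].second, 0.0f));
  }
  return out;
}
```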

Template Parameters:
  • DataT – data type

  • OutT – output type, which either stores the 1-NN indices together with their minimum distances or stores only the minimum distances. An appropriate ReduceOpT must be passed accordingly.

  • IdxT – indexing arithmetic type

  • ReduceOpT – a struct that performs the final reduction operation and also initializes the output array elements with the initial value needed for the reduction.

  • KVPReduceOpT – type of the reduction operation on key-value pairs.

Parameters:
  • min – [out] Reduced output. Length = m. (on device)

  • x – [in] First matrix. Row major. Dim = m x k. (on device)

  • y – [in] Second matrix. Row major. Dim = n x k. (on device)

  • xn – [in] Squared L2 norms of the rows of x. Length = m. (on device)

  • yn – [in] Squared L2 norms of the rows of y. Length = n. (on device)

  • m – [in] Number of rows of x (the GEMM m dimension).

  • n – [in] Number of rows of y (the GEMM n dimension).

  • k – [in] Number of columns of x and y (the GEMM k dimension).

  • workspace – [in] Temporary workspace. Size = sizeof(int) * m bytes. (on device)

  • redOp – [in] Reduction operator in the epilogue.

  • pairRedOp – [in] Reduction operation on key-value pairs.

  • sqrt – [in] Whether the output minimum distance should be the L2 distance (square root taken) rather than the squared L2 distance.

  • initOutBuffer – [in] Whether to initialize the output buffer before the main kernel launch.

  • stream – [in] CUDA stream.
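The division of labor between OutT, ReduceOpT and KVPReduceOpT can be illustrated with hypothetical host-side operators. Kvp, KvpMin, StoreIndexAndDistance, StoreDistanceOnly and reduce_row are illustrative names only; RAFT's own operators (KVPMinReduce, MinAndDistanceReduceOp, MinReduceOp) have their own interfaces, which may differ from these. The sketch assumes that the pair operator combines partial (index, distance) candidates, as threads might in a parallel reduction, while the epilogue operator initializes the output and writes the final result in whichever OutT it was designed for.

```cpp
#include <cassert>
#include <limits>

// Hypothetical stand-in for a key-value pair such as raft::KeyValuePair<int, float>.
struct Kvp {
  int key;      // index of the candidate neighbor in y
  float value;  // its distance
};

// pairRedOp-style operator: combine two candidates, keeping the closer one.
struct KvpMin {
  Kvp operator()(const Kvp& a, const Kvp& b) const {
    return b.value < a.value ? b : a;
  }
};

// redOp-style operator for OutT = Kvp: keeps both index and distance.
struct StoreIndexAndDistance {
  void init(Kvp* out, float max_val) const { *out = {0, max_val}; }
  void operator()(Kvp* out, const Kvp& cand) const {
    if (cand.value < out->value) *out = cand;
  }
};

// redOp-style operator for OutT = float: keeps only the distance.
struct StoreDistanceOnly {
  void init(float* out, float max_val) const { *out = max_val; }
  void operator()(float* out, const Kvp& cand) const {
    if (cand.value < *out) *out = cand.value;
  }
};

// Reduces one row of distances into *out using the chosen operators.
template <typename OutT, typename RedOp, typename PairRedOp>
void reduce_row(const float* dists, int n, OutT* out, RedOp red_op,
                PairRedOp pair_op) {
  assert(n >= 0);
  const float max_val = std::numeric_limits<float>::max();
  red_op.init(out, max_val);
  // Two partial minima, as two threads might hold before combining.
  Kvp lo{0, max_val}, hi{0, max_val};
  for (int j = 0; j < n; ++j) {
    const Kvp cand{j, dists[j]};
    if (j < n / 2) lo = pair_op(lo, cand); else hi = pair_op(hi, cand);
  }
  // Combine the partials, then let redOp write the result in its OutT.
  red_op(out, pair_op(lo, hi));
}
```

The same distances yield either an (index, distance) pair or a bare distance depending only on which epilogue operator is paired with the output type.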

template<typename DataT, typename OutT, typename IdxT>
void fusedL2NNMinReduce(OutT *min, const DataT *x, const DataT *y, const DataT *xn, const DataT *yn, IdxT m, IdxT n, IdxT k, void *workspace, bool sqrt, bool initOutBuffer, cudaStream_t stream)

Wrapper around fusedL2NN with minimum reduction operators.

Because fusedL2NN is templated on arbitrary reduction operators (for example, lambdas), it cannot be precompiled into the distance library, so this wrapper covers the most common case (minimum reduction). Prefer it over the more generic API when possible, to reduce compilation times for users of the shared library.

Template Parameters:
  • DataT – data type

  • OutT – output type to either store 1-NN indices and their minimum distances (e.g. raft::KeyValuePair<int, float>) or store only the min distances.

  • IdxT – indexing arithmetic type

Parameters:
  • min – [out] Reduced output. Length = m. (on device)

  • x – [in] First matrix. Row major. Dim = m x k. (on device)

  • y – [in] Second matrix. Row major. Dim = n x k. (on device)

  • xn – [in] Squared L2 norms of the rows of x. Length = m. (on device)

  • yn – [in] Squared L2 norms of the rows of y. Length = n. (on device)

  • m – [in] Number of rows of x (the GEMM m dimension).

  • n – [in] Number of rows of y (the GEMM n dimension).

  • k – [in] Number of columns of x and y (the GEMM k dimension).

  • workspace – [in] Temporary workspace. Size = sizeof(int) * m bytes. (on device)

  • sqrt – [in] Whether the output minimum distance should be the L2 distance (square root taken) rather than the squared L2 distance.

  • initOutBuffer – [in] Whether to initialize the output buffer before the main kernel launch.

  • stream – [in] CUDA stream.

template<typename LabelT, typename DataT>
using KVPMinReduce = detail::KVPMinReduceImpl<LabelT, DataT>
template<typename LabelT, typename DataT>
using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl<LabelT, DataT>
template<typename LabelT, typename DataT>
using MinReduceOp = detail::MinReduceOpImpl<LabelT, DataT>

#include <raft/distance/masked_nn.cuh>

namespace raft::distance

template<typename DataT, typename OutT, typename IdxT, typename ReduceOpT, typename KVPReduceOpT>
void masked_l2_nn(raft::resources const &handle, raft::distance::masked_l2_nn_params<ReduceOpT, KVPReduceOpT> params, raft::device_matrix_view<const DataT, IdxT, raft::layout_c_contiguous> x, raft::device_matrix_view<const DataT, IdxT, raft::layout_c_contiguous> y, raft::device_vector_view<const DataT, IdxT, raft::layout_c_contiguous> x_norm, raft::device_vector_view<const DataT, IdxT, raft::layout_c_contiguous> y_norm, raft::device_matrix_view<const bool, IdxT, raft::layout_c_contiguous> adj, raft::device_vector_view<const IdxT, IdxT, raft::layout_c_contiguous> group_idxs, raft::device_vector_view<OutT, IdxT, raft::layout_c_contiguous> out)

Masked L2 distance and 1-nearest-neighbor computation in a single call.

This function enables faster computation of nearest neighbors if the computation of distances between certain point pairs can be skipped.

We use an adjacency matrix that describes which distances to calculate. The points in y are divided into groups, and the adjacency matrix indicates whether to compute distances between points in x and groups in y. In other words, if adj[i, j] is true, then the distances between point x_i and all points in group j of y are calculated; otherwise they are skipped.
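A serial CPU sketch of these masking semantics follows. It assumes the group layout given in the group_idxs parameter description (group 0 starts at row 0, and group g ends just before row group_idxs[g]). masked_1nn_reference is a hypothetical name; it computes squared distances directly rather than through norms and ignores tiling, so it shows only which pairs are considered, not how the GPU kernel schedules them. In this sketch, rows of x with no adjacent group keep the initial value (-1, FLT_MAX).

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical CPU reference for masked 1-NN with squared L2 distances.
// x: m x k, y: n x k (row major). adj: m x num_groups (row major).
// group_idxs[g]: exclusive end row of group g in y (inclusive scan of lengths).
std::vector<std::pair<int, float>> masked_1nn_reference(
    const std::vector<float>& x, const std::vector<float>& y,
    const std::vector<bool>& adj, const std::vector<int>& group_idxs,
    int m, int k)
{
  const int num_groups = static_cast<int>(group_idxs.size());
  assert(x.size() == static_cast<std::size_t>(m) * k);
  assert(adj.size() == static_cast<std::size_t>(m) * num_groups);

  std::vector<std::pair<int, float>> out(
      m, {-1, std::numeric_limits<float>::max()});
  for (int i = 0; i < m; ++i) {
    for (int g = 0; g < num_groups; ++g) {
      if (!adj[i * num_groups + g]) continue;  // x_i skips group g entirely
      const int begin = (g == 0) ? 0 : group_idxs[g - 1];
      for (int j = begin; j < group_idxs[g]; ++j) {
        float dist = 0.0f;
        for (int d = 0; d < k; ++d) {
          const float diff = x[i * k + d] - y[j * k + d];
          dist += diff * diff;
        }
        if (dist < out[i].second) out[i] = {j, dist};
      }
    }
  }
  return out;
}
```

With x = {0, 10}, y = {1, 2, 9, 11} split into groups {1, 2} and {9, 11}, and x_0 allowed only the second group, x_0's nearest neighbor becomes row 2 (distance 81) even though row 0 is closer overall.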

Performance considerations

The points in x are processed in tiles of M points (M is currently 64, but may change in the future). As a result, the largest compute time reduction occurs if all M points can skip a group. If only part of the M points can skip a group, then at most a minor compute time reduction and a modest energy use reduction can be expected.
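The tile rule above can be expressed as a small host-side check. count_skippable_tile_groups is a hypothetical helper, and its model (consecutive tiles of tile_m rows of x, where a tile skips a group only if no row in the tile is adjacent to that group) is an assumption drawn from this description, not from the kernel source.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Counts (x tile, group) pairs that can be skipped entirely under the model
// described above. adj is m x num_groups, row major.
int count_skippable_tile_groups(const std::vector<bool>& adj, int m,
                                int num_groups, int tile_m)
{
  assert(tile_m > 0);
  assert(adj.size() == static_cast<std::size_t>(m) * num_groups);
  int skippable = 0;
  for (int tile_start = 0; tile_start < m; tile_start += tile_m) {
    const int tile_end = std::min(tile_start + tile_m, m);
    for (int g = 0; g < num_groups; ++g) {
      // A single adjacent row forces the whole tile to process group g.
      bool any_needed = false;
      for (int i = tile_start; i < tile_end && !any_needed; ++i)
        any_needed = adj[i * num_groups + g];
      if (!any_needed) ++skippable;
    }
  }
  return skippable;
}
```

With four rows and two groups where rows 0 and 1 need only group 1 and row 2 needs only group 0, tiles of two rows can each skip one group, while a single four-row tile can skip nothing.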

The points in y are also grouped into tiles of N points (N is currently 64, but may change in the future). As a result, group sizes should be larger than N to avoid wasting computational resources. If the group sizes are evenly divisible by N, then the computation is most efficient, although for larger group sizes this effect is minor.
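A rough way to see why small or misaligned groups waste work is to count how many N-row tiles of y each group overlaps. tiles_per_group is a hypothetical helper built on a simplified cost model (tiles start at row 0 of y, and every overlapped tile is paid in full for that group); the actual kernel may handle tiles shared by neighboring groups differently.

```cpp
#include <cassert>
#include <vector>

// For each group (given as exclusive end indices, group 0 starting at row 0),
// returns the number of tile_n-row tiles of y that the group overlaps.
std::vector<int> tiles_per_group(const std::vector<int>& group_idxs, int tile_n)
{
  assert(tile_n > 0);
  std::vector<int> tiles;
  int begin = 0;
  for (int end : group_idxs) {
    assert(end >= begin);
    // First tile is begin / tile_n, last tile is (end - 1) / tile_n.
    tiles.push_back(end == begin ? 0 : (end - 1) / tile_n - begin / tile_n + 1);
    begin = end;
  }
  return tiles;
}
```

For example, with tile_n = 64, boundaries {64, 128} give one tile per group, boundaries {48, 96, 160} give {1, 2, 2} (five tiles of work for 160 rows), and three 10-row groups {10, 20, 30} each still cost a full tile.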

Comparison to SDDMM

SDDMM (sampled dense-dense matrix multiplication) is a matrix-matrix multiplication where only part of the output is computed. Compared to masked_l2_nn, there are a few differences:

  • The output of masked_l2_nn is a single vector (of nearest neighbors) and not a sparse matrix.

  • The sampling in masked_l2_nn is expressed through intermediate “groups” rather than a CSR format.

Template Parameters:
  • DataT – data type

  • OutT – output type, which either stores the 1-NN indices together with their minimum distances or stores only the minimum distances. An appropriate ReduceOpT must be passed accordingly.

  • IdxT – indexing arithmetic type

  • ReduceOpT – a struct that performs the final reduction operation and also initializes the output array elements with the initial value needed for the reduction.

  • KVPReduceOpT – type of the reduction operation on key-value pairs.

Parameters:
  • handle – RAFT handle for managing expensive resources

  • params – Parameter struct specifying the reduction operations.

  • x – [in] First matrix. Row major. Dim = m x k. (on device)

  • y – [in] Second matrix. Row major. Dim = n x k. (on device)

  • x_norm – [in] Squared L2 norms of the rows of x. Length = m. (on device)

  • y_norm – [in] Squared L2 norms of the rows of y. Length = n. (on device)

  • adj – [in] Boolean adjacency matrix indicating, for each row of x and each group in y, whether to compute the distances. Dim = m x num_groups. (on device)

  • group_idxs – [in] Array containing the end index of each group in y. The value of group_idxs[j] is the start of group j + 1, i.e., group_idxs is the inclusive scan of the group lengths. The first group is always assumed to start at index 0, and the last group typically ends at index n. Length = num_groups. (on device)

  • out – [out] Reduced output. Length = m. (on device)
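Since group_idxs is the inclusive scan of the group lengths, it can be built on the host with std::partial_sum, and the group containing a given row of y can be found with a binary search. make_group_idxs and group_of are hypothetical helpers for illustration; copying the result into device memory for the group_idxs view is not shown.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Builds group_idxs from group lengths: the inclusive scan of the lengths.
std::vector<int> make_group_idxs(const std::vector<int>& group_lengths)
{
  std::vector<int> idxs(group_lengths.size());
  std::partial_sum(group_lengths.begin(), group_lengths.end(), idxs.begin());
  return idxs;
}

// Returns the group containing row j of y: the first group whose end index
// is greater than j (empty groups are skipped automatically).
int group_of(const std::vector<int>& group_idxs, int j)
{
  auto it = std::upper_bound(group_idxs.begin(), group_idxs.end(), j);
  assert(it != group_idxs.end());  // j must be below the last end index
  return static_cast<int>(it - group_idxs.begin());
}
```

For group lengths {3, 2, 4}, group_idxs is {3, 5, 9}: rows 0 to 2 form group 0, rows 3 and 4 form group 1, and rows 5 to 8 form group 2.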

template<typename ReduceOpT, typename KVPReduceOpT>
struct masked_l2_nn_params

Parameter struct for masked_l2_nn function.

Usage example:

#include <raft/distance/masked_nn.cuh>

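using IdxT        = int;
using DataT       = float;
// Epilogue: store the nearest neighbor's index and distance in the output.
using RedOpT      = raft::distance::MinAndDistanceReduceOp<IdxT, DataT>;
// Combine two (index, distance) pairs by keeping the smaller distance.
using PairRedOpT  = raft::distance::KVPMinReduce<IdxT, DataT>;
using ParamT      = raft::distance::masked_l2_nn_params<RedOpT, PairRedOpT>;

bool init_out = true;   // initialize the output buffer before the kernel runs
bool sqrt     = false;  // keep squared L2 distances

// Members in declaration order: redOp, pairRedOp, sqrt, initOutBuffer.
ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out};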

Prescribes how to reduce a distance to an intermediate type (redOp), and how to reduce two intermediate types (pairRedOp). Typically, a distance is mapped to an (index, value) pair, and the (index, value) pair with the lowest value (distance) is selected.

In addition, prescribes whether to compute the square root of the distance (sqrt) and whether to initialize the output buffer (initOutBuffer).

Template Parameters:
  • ReduceOpT – Type of reduction operator in the epilogue.

  • KVPReduceOpT – Type of reduction operation on key-value pairs.

Public Members

ReduceOpT redOp

Reduction operator in the epilogue.

KVPReduceOpT pairRedOp

Reduction operation on key-value pairs.

bool sqrt

Whether the output minimum distance should be the L2 distance (square root taken) rather than the squared L2 distance.

bool initOutBuffer

Whether to initialize the output buffer before the main kernel launch.