Matrix Ordering

Argmax

#include <raft/matrix/argmax.cuh>

namespace raft::matrix

template<typename math_t, typename idx_t, typename matrix_idx_t>
void argmax(raft::resources const &handle, raft::device_matrix_view<const math_t, matrix_idx_t, row_major> in, raft::device_vector_view<idx_t, matrix_idx_t> out)

Argmax: find the column index of the maximum value in each row.

Parameters:
  • handle – [in] raft handle

  • in – [in] input matrix of size (n_rows, n_cols)

  • out – [out] output vector of size n_rows; element i holds the column index of the maximum value in row i
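To make these semantics concrete, here is a minimal host-side sketch of what argmax computes, assuming a dense row-major matrix stored in a std::vector. It is a plain C++ reference for illustration only, not RAFT's GPU implementation, and the name argmax_rows_ref is hypothetical. In this sketch ties resolve to the lowest column index; the documentation above does not specify RAFT's tie-breaking.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host reference for raft::matrix::argmax (not part of RAFT).
// `in` is a row-major (n_rows x n_cols) matrix; returns one column index per row.
// Ties resolve to the lowest column index in this sketch.
template <typename math_t, typename idx_t>
std::vector<idx_t> argmax_rows_ref(const std::vector<math_t>& in,
                                   std::size_t n_rows, std::size_t n_cols)
{
  assert(in.size() == n_rows * n_cols);
  std::vector<idx_t> out(n_rows, idx_t{0});
  for (std::size_t r = 0; r < n_rows; ++r) {
    const math_t* row = in.data() + r * n_cols;  // start of row r
    std::size_t best = 0;
    for (std::size_t c = 1; c < n_cols; ++c) {
      if (row[c] > row[best]) { best = c; }  // strict '>' keeps the first maximum
    }
    out[r] = static_cast<idx_t>(best);
  }
  return out;
}
```

For example, the 2 x 3 matrix {1, 5, 3, 7, 2, 7} yields {1, 0}: the maximum of row 0 is at column 1, and the value 7 first appears at column 0 of row 1.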

Argmin

#include <raft/matrix/argmin.cuh>

namespace raft::matrix

template<typename math_t, typename idx_t, typename matrix_idx_t>
void argmin(raft::resources const &handle, raft::device_matrix_view<const math_t, matrix_idx_t, row_major> in, raft::device_vector_view<idx_t, matrix_idx_t> out)

Argmin: find the column index of the minimum value in each row.

Parameters:
  • handle – [in] raft handle

  • in – [in] input matrix of size (n_rows, n_cols)

  • out – [out] output vector of size n_rows; element i holds the column index of the minimum value in row i
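As an illustration of these semantics, here is a host-side sketch written with std::min_element over each row of a row-major std::vector. argmin_rows_ref is a hypothetical name, not RAFT code. Its lowest-column tie-breaking comes from std::min_element returning the first smallest element, not from any documented RAFT behavior.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical host reference for raft::matrix::argmin (not part of RAFT).
// Applies std::min_element to each row of a row-major (n_rows x n_cols) buffer.
// std::min_element returns the first smallest element, so ties go to the lowest column.
template <typename math_t, typename idx_t>
std::vector<idx_t> argmin_rows_ref(const std::vector<math_t>& in,
                                   std::size_t n_rows, std::size_t n_cols)
{
  assert(in.size() == n_rows * n_cols);
  std::vector<idx_t> out;
  out.reserve(n_rows);
  for (std::size_t r = 0; r < n_rows; ++r) {
    auto first = in.begin() + static_cast<std::ptrdiff_t>(r * n_cols);
    auto last  = first + static_cast<std::ptrdiff_t>(n_cols);
    // The distance from the row start to the smallest element is its column index.
    out.push_back(static_cast<idx_t>(std::distance(first, std::min_element(first, last))));
  }
  return out;
}
```

For the 2 x 3 matrix {4, -1, -1, 0, 9, 3}, this returns {1, 0}: row 0 has a tie at columns 1 and 2, and the first one wins.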

Select-K

#include <raft/matrix/select_k.cuh>

namespace raft::matrix

template<typename T, typename IdxT>
void select_k(raft::resources const &handle, raft::device_matrix_view<const T, int64_t, row_major> in_val, std::optional<raft::device_matrix_view<const IdxT, int64_t, row_major>> in_idx, raft::device_matrix_view<T, int64_t, row_major> out_val, raft::device_matrix_view<IdxT, int64_t, row_major> out_idx, bool select_min, bool sorted = false)

Select the k smallest or largest key/value pairs from each row of the input data.

Treating the input data in_val as a row-major matrix with batch_size rows and len columns, this function selects the k smallest (or largest) values in each row and writes them to the row-major matrix out_val of size (batch_size, k), with the corresponding payload written to out_idx.

Example usage

using namespace raft;
// get a 2D row-major array of values to search through
auto in_values = {... input device_matrix_view<const float, int64_t, row_major> ...};
// prepare output arrays; `k` is the number of values to select from each row
auto out_extents = make_extents<int64_t>(in_values.extent(0), k);
auto out_values  = make_device_mdarray<float>(handle, out_extents);
auto out_indices = make_device_mdarray<int64_t>(handle, out_extents);
// search `k` smallest values in each row
matrix::select_k<float, int64_t>(
  handle, in_values, std::nullopt, out_values.view(), out_indices.view(), true);

Template Parameters:
  • T – the type of the keys (what is being compared).

  • IdxT – the index type (what is being selected together with the keys).

Parameters:
  • handle – [in] container of reusable resources

  • in_val – [in] input values [batch_size, len]; these are compared and selected.

  • in_idx – [in] optional input payload [batch_size, len]; typically, these are the indices of the corresponding in_val. If in_idx is std::nullopt, a contiguous array 0...len-1 is implied.

  • out_val – [out] output values [batch_size, k]; the k smallest/largest values from each row of in_val.

  • out_idx – [out] output payload (e.g. indices) [batch_size, k]; the payload selected together with out_val.

  • select_min – [in] whether to select the k smallest (true) or largest (false) keys.

  • sorted – [in] whether the selected pairs in each output row must be sorted by value (default: false)
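The following host-side sketch spells out the selection semantics described above, including the implied 0...len-1 payload when in_idx is absent. It is a hypothetical reference (select_k_ref is not a RAFT function) built on std::partial_sort over (value, payload) pairs, a swapped-in CPU technique rather than RAFT's GPU algorithm. Because std::partial_sort orders the selected elements, its output always corresponds to sorted = true; the documentation above only guarantees that ordering when sorted is requested. Tie handling is not specified here either.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical host reference for the semantics of raft::matrix::select_k (not RAFT code).
// in_val: row-major [batch_size, len]; in_idx: optional payload of the same shape.
// Writes row-major [batch_size, k] results into out_val / out_idx.
template <typename T, typename IdxT>
void select_k_ref(const std::vector<T>& in_val,
                  const std::optional<std::vector<IdxT>>& in_idx,
                  std::size_t batch_size, std::size_t len, std::size_t k,
                  std::vector<T>& out_val, std::vector<IdxT>& out_idx,
                  bool select_min)
{
  assert(k <= len && in_val.size() == batch_size * len);
  out_val.assign(batch_size * k, T{});
  out_idx.assign(batch_size * k, IdxT{});
  std::vector<std::pair<T, IdxT>> row(len);
  auto cmp = [select_min](const auto& a, const auto& b) {
    return select_min ? a.first < b.first : a.first > b.first;
  };
  for (std::size_t r = 0; r < batch_size; ++r) {
    for (std::size_t j = 0; j < len; ++j) {
      // Without in_idx, the payload is the implied column index 0...len-1.
      IdxT payload = in_idx ? (*in_idx)[r * len + j] : static_cast<IdxT>(j);
      row[j] = {in_val[r * len + j], payload};
    }
    // partial_sort moves the k best pairs, in order, to the front of the row.
    std::partial_sort(row.begin(), row.begin() + static_cast<std::ptrdiff_t>(k),
                      row.end(), cmp);
    for (std::size_t j = 0; j < k; ++j) {
      out_val[r * k + j] = row[j].first;
      out_idx[r * k + j] = row[j].second;
    }
  }
}
```

For the 2 x 4 input {5, 1, 4, 2, 9, 7, 8, 6} with k = 2 and select_min = true, this yields out_val = {1, 2, 6, 7} and out_idx = {1, 3, 3, 1}.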

Column-wise Sort

#include <raft/matrix/col_wise_sort.cuh>

namespace raft::matrix

template<typename in_t, typename out_t, typename matrix_idx_t, typename sorted_keys_t>
void sort_cols_per_row(raft::resources const &handle, raft::device_matrix_view<const in_t, matrix_idx_t, raft::row_major> in, raft::device_matrix_view<out_t, matrix_idx_t, raft::row_major> out, sorted_keys_t &&sorted_keys_opt)

Sort the columns within each row of a row-major input matrix and return the sorted column indices. This is modelled as a key-value sort in which the keys are the elements of the input matrix and the values are their column indices.

Template Parameters:
  • in_t – element type of input matrix

  • out_t – element type of output matrix

  • matrix_idx_t – integer type for matrix indexing

  • sorted_keys_t – type of sorted_keys_opt, i.e. std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>>

Parameters:
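  • handle – [in] raft handle

  • in – [in] input matrix

  • out – [out] output matrix of values, i.e. the column indices of each row in sorted key order

  • sorted_keys_opt – [out] optional output matrix that receives the sorted keys (the input values in sorted order)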
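The key-value formulation can be illustrated with a host-side sketch: for each row, the column indices 0...n_cols-1 are the values, and they are reordered by their keys. sort_cols_per_row_ref is a hypothetical helper, not RAFT code. It assumes ascending order and uses std::stable_sort so that equal keys keep their column order; neither choice is stated by the documentation above. A null sorted_keys pointer plays the role of std::nullopt.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical host reference for raft::matrix::sort_cols_per_row (not RAFT code).
// Key-value sort per row of a row-major (n_rows x n_cols) matrix `in`:
// keys are the row's elements, values are their column indices (ascending assumed).
// `out` receives the column indices in key order; `sorted_keys`, if non-null, the keys.
template <typename in_t, typename out_t>
void sort_cols_per_row_ref(const std::vector<in_t>& in, std::size_t n_rows,
                           std::size_t n_cols, std::vector<out_t>& out,
                           std::vector<in_t>* sorted_keys = nullptr)
{
  assert(in.size() == n_rows * n_cols);
  out.resize(n_rows * n_cols);
  if (sorted_keys != nullptr) { sorted_keys->resize(n_rows * n_cols); }
  std::vector<out_t> cols(n_cols);
  for (std::size_t r = 0; r < n_rows; ++r) {
    const in_t* keys = in.data() + r * n_cols;
    std::iota(cols.begin(), cols.end(), out_t{0});  // values: 0, 1, ..., n_cols-1
    // Order the column indices by their keys; stable_sort keeps equal keys in column order.
    std::stable_sort(cols.begin(), cols.end(),
                     [keys](out_t a, out_t b) { return keys[a] < keys[b]; });
    for (std::size_t c = 0; c < n_cols; ++c) {
      out[r * n_cols + c] = cols[c];
      if (sorted_keys != nullptr) { (*sorted_keys)[r * n_cols + c] = keys[cols[c]]; }
    }
  }
}
```

For the 2 x 3 matrix {3, 1, 2, 0.5, 0.25, 0.75}, the index output is {1, 2, 0, 1, 0, 2} and the sorted keys are {1, 2, 3, 0.25, 0.5, 0.75}.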

template<typename ...Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
void sort_cols_per_row(Args... args)

Overload of sort_cols_per_row, enabled only when exactly three arguments (handle, in, out) are passed. It helps the compiler find the overload above when the optional sorted_keys_opt argument is omitted, supplying std::nullopt in its place.

Please see above for documentation of sort_cols_per_row.
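The need for this overload can be sketched in plain C++17. In the primary overload, sorted_keys_t is deduced from the fourth argument, so a call that passes only three arguments cannot deduce it. A second template constrained with std::enable_if_t<sizeof...(Args) == 3> can accept the three arguments and forward them with std::nullopt appended. The function name compute below is a hypothetical stand-in, not RAFT code, and its body only reports whether an optional output was supplied.

```cpp
#include <cassert>
#include <optional>
#include <type_traits>
#include <utility>

// Hypothetical primary function with the same shape as the 4-parameter overload:
// the last parameter's type is deduced, so it cannot simply be omitted by callers.
template <typename A, typename B, typename C, typename opt_t>
bool compute(A, B, C, opt_t&& opt)
{
  // Normalize whatever was passed (std::nullopt or an optional) to std::optional<int>.
  std::optional<int> o = std::forward<opt_t>(opt);
  return o.has_value();
}

// Forwarding overload in the style of the 3-argument sort_cols_per_row:
// it participates only for exactly three arguments and appends std::nullopt.
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
bool compute(Args... args)
{
  return compute(std::forward<Args>(args)..., std::nullopt);
}
```

A four-argument call removes the variadic template from overload resolution through SFINAE, so compute(1, 2, 3, std::nullopt) and compute(1, 2, 3) both reach the primary template.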