// SPDX-FileCopyrightText: Copyright (c) 2011, Duane Merrill. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2011-2025, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: BSD-3

/**
 * @file
 * Simple binary operator functor types
 */

/******************************************************************************
 * Simple functor operators
 ******************************************************************************/

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/util_type.cuh>

#include <cuda/__functional/maximum.h>
#include <cuda/__functional/minimum.h>
#include <cuda/std/__functional/operations.h>
#include <cuda/std/__utility/integer_sequence.h>
#include <cuda/std/__utility/pair.h>
#include <cuda/std/cstdint>
#include <cuda/std/limits>

CUB_NAMESPACE_BEGIN

// TODO(bgruber): deprecate in C++17 with a note: "replace by decltype(cuda::std::not_fn(EqualityOp{}))"
/**
 * @brief Adapter that turns an equality functor into an inequality functor
 *
 * @tparam EqualityOp
 *   Type of the wrapped equality functor
 */
template <typename EqualityOp>
struct InequalityWrapper
{
  /// The equality functor whose result gets negated
  EqualityOp op;

  /// Stores a copy of @p equality_op as the wrapped functor
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE InequalityWrapper(EqualityOp equality_op)
      : op(equality_op)
  {}

  /**
   * @brief Perfect-forwards both arguments to the wrapped functor and negates what it returns
   *
   * @param[in] lhs
   *   First argument handed to the wrapped functor
   *
   * @param[in] rhs
   *   Second argument handed to the wrapped functor
   *
   * @return `!op(lhs, rhs)`
   */
  template <typename LhsT, typename RhsT>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE bool operator()(LhsT&& lhs, RhsT&& rhs)
  {
    return !op(::cuda::std::forward<LhsT>(lhs), ::cuda::std::forward<RhsT>(rhs));
  }
};

/// @brief Arg max functor: of two (offset, value) pairs, keeps the one holding the larger value, and on equal
///        values the one holding the smaller offset
struct ArgMax
{
  /**
   * @brief Selects between two key-value pairs (`key` is of type `OffsetT`, `value` is of type `T`)
   *
   * @param[in] a
   *   First candidate; returned unless @p b wins
   *
   * @param[in] b
   *   Second candidate; wins when its value is strictly larger, or when the values compare equal and its key
   *   is smaller
   *
   * @return A copy of the selected pair
   */
  template <typename T, typename OffsetT>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePair<OffsetT, T>
  operator()(const KeyValuePair<OffsetT, T>& a, const KeyValuePair<OffsetT, T>& b) const
  {
    // NOTE: the selection is written with explicit branches on purpose. The single conditional-expression form
    //   ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a
    // is recorded as having triggered a bug ("Mooch BUG": device reduce argmax, gk110, 3.2 million random fp32).

    // A strictly larger value always wins
    if (b.value > a.value)
    {
      return b;
    }

    // Equal values: the smaller key (offset) wins
    if ((a.value == b.value) && (b.key < a.key))
    {
      return b;
    }

    return a;
  }
};

/// @brief Arg min functor: of two (offset, value) pairs, keeps the one holding the smaller value, and on equal
///        values the one holding the smaller offset
struct ArgMin
{
  /**
   * @brief Selects between two key-value pairs (`key` is of type `OffsetT`, `value` is of type `T`)
   *
   * @param[in] a
   *   First candidate; returned unless @p b wins
   *
   * @param[in] b
   *   Second candidate; wins when its value is strictly smaller, or when the values compare equal and its key
   *   is smaller
   *
   * @return A copy of the selected pair
   */
  template <typename T, typename OffsetT>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePair<OffsetT, T>
  operator()(const KeyValuePair<OffsetT, T>& a, const KeyValuePair<OffsetT, T>& b) const
  {
    // NOTE: the selection is written with explicit branches on purpose. The single conditional-expression form
    //   ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a
    // is recorded as having triggered a bug ("Mooch BUG": device reduce argmax, gk110, 3.2 million random fp32).

    // A strictly smaller value always wins
    if (b.value < a.value)
    {
      return b;
    }

    // Equal values: the smaller key (offset) wins
    if ((a.value == b.value) && (b.key < a.key))
    {
      return b;
    }

    return a;
  }
};

namespace detail
{
/// @brief Arg max functor over `::cuda::std::pair<OffsetT, T>`: keeps the pair holding the larger `second`, and on
///        equal `second` the pair holding the smaller `first`
struct arg_max
{
  /**
   * @brief Selects between two (offset, value) pairs
   *
   * @param[in] a
   *   First candidate; returned unless @p b wins
   *
   * @param[in] b
   *   Second candidate; wins when `b.second > a.second`, or when both `second` members compare equal and
   *   `b.first < a.first`
   *
   * @return A copy of the selected pair
   */
  template <typename T, typename OffsetT>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ::cuda::std::pair<OffsetT, T>
  operator()(const ::cuda::std::pair<OffsetT, T>& a, const ::cuda::std::pair<OffsetT, T>& b) const
  {
    // A strictly larger value always wins
    if (b.second > a.second)
    {
      return b;
    }

    // Equal values: the smaller offset wins
    if ((a.second == b.second) && (b.first < a.first))
    {
      return b;
    }

    return a;
  }
};

/// @brief Arg min functor over `::cuda::std::pair<OffsetT, T>`: keeps the pair holding the smaller `second`, and
///        on equal `second` the pair holding the smaller `first`
struct arg_min
{
  /**
   * @brief Selects between two (offset, value) pairs
   *
   * @param[in] a
   *   First candidate; returned unless @p b wins
   *
   * @param[in] b
   *   Second candidate; wins when `b.second < a.second`, or when both `second` members compare equal and
   *   `b.first < a.first`
   *
   * @return A copy of the selected pair
   */
  template <typename T, typename OffsetT>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ::cuda::std::pair<OffsetT, T>
  operator()(const ::cuda::std::pair<OffsetT, T>& a, const ::cuda::std::pair<OffsetT, T>& b) const
  {
    // A strictly smaller value always wins
    if (b.second < a.second)
    {
      return b;
    }

    // Equal values: the smaller offset wins
    if ((a.second == b.second) && (b.first < a.first))
    {
      return b;
    }

    return a;
  }
};

/**
 * @brief Scan-by-segment functor for key-value pairs whose `key` acts as a head flag
 *
 * The combined `key` is the bitwise OR of both input keys. The combined `value` is `second.value` when
 * `second.key` is non-zero, and `op(first.value, second.value)` otherwise.
 *
 * @tparam ScanOpT
 *   Binary operator applied to the values
 */
template <typename ScanOpT>
struct ScanBySegmentOp
{
  /// Binary operator used to combine values
  ScanOpT op;

  /// Default constructor; `op` is default-initialized
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanBySegmentOp() {}

  /// Stores a copy of @p scan_op as the wrapped operator
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanBySegmentOp(ScanOpT scan_op)
      : op(scan_op)
  {}

  /**
   * @brief Combines two partial results
   *
   * @tparam KeyValuePairT
   *   KeyValuePair pairing of T (value) and int (head flag)
   *
   * @param[in] first
   *   Partial result on the left-hand side
   *
   * @param[in] second
   *   Partial result on the right-hand side
   *
   * @return The combined partial result
   */
  template <typename KeyValuePairT>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePairT operator()(const KeyValuePairT& first, const KeyValuePairT& second)
  {
    KeyValuePairT result;
    result.key = first.key | second.key;
#ifdef _NVHPC_CUDA // WAR bug on nvc++
    if (second.key)
    {
      result.value = second.value;
    }
    else
    {
      // nvc++ crashes while compiling the TestScanByKeyWithLargeTypes test in
      // thrust/testing/scan_by_key.cu unless second.value is first copied
      // into a temporary:
      auto second_value = second.value;
      result.value      = op(first.value, second_value);
    }
#else // not nvc++:
    // Non-zero second.key: the second partial result spans a segment reset, so
    // its value aggregate becomes the running aggregate.
    // Zero second.key: no reset is spanned, so both value aggregates are
    // accumulated into the running aggregate.
    result.value = (second.key) ? second.value : op(first.value, second.value);
#endif
    return result;
  }
};

/// @brief Trait whose `value` is `true` for specializations of `::cuda::std::plus`, `::cuda::minimum` and
///        `::cuda::maximum`, and `false` for every other type
template <class OperatorT>
struct basic_binary_op_t
{
  static constexpr bool value = false;
};

/// Specialization for `::cuda::minimum`
template <typename ValueT>
struct basic_binary_op_t<::cuda::minimum<ValueT>>
{
  static constexpr bool value = true;
};

/// Specialization for `::cuda::maximum`
template <typename ValueT>
struct basic_binary_op_t<::cuda::maximum<ValueT>>
{
  static constexpr bool value = true;
};

/// Specialization for `::cuda::std::plus`
template <typename ValueT>
struct basic_binary_op_t<::cuda::std::plus<ValueT>>
{
  static constexpr bool value = true;
};
} // namespace detail

/**
 * @brief Default cast functor
 *
 * @tparam B
 *   Destination type of the cast
 */
template <typename B>
struct CastOp
{
  /**
   * @brief Converts the argument to `B` with a C-style cast
   *
   * @param[in] source
   *   Value to convert
   *
   * @return `(B) source`
   */
  template <typename SourceT>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE B operator()(SourceT&& source) const
  {
    return (B) source;
  }
};

/**
 * @brief Binary operator wrapper that invokes the wrapped scan operator with its two arguments exchanged
 *        (useful for non-commutative operators)
 *
 * @tparam ScanOp
 *   Type of the wrapped binary operator
 */
template <typename ScanOp>
class SwizzleScanOp
{
private:
  /// The operator being wrapped
  ScanOp scan_op;

public:
  /// Stores a copy of @p scan_op as the wrapped operator
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE SwizzleScanOp(ScanOp scan_op)
      : scan_op(scan_op)
  {}

  /**
   * @brief Applies the wrapped operator as `scan_op(b, a)`
   *
   * @param[in] a
   *   Passed to the wrapped operator as its second argument
   *
   * @param[in] b
   *   Passed to the wrapped operator as its first argument
   *
   * @return The wrapped operator's result, converted to `T`
   */
  template <typename T>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T operator()(const T& a, const T& b)
  {
    // The wrapped operator receives local, non-const copies of the inputs
    T copy_of_a(a);
    T copy_of_b(b);

    return scan_op(copy_of_b, copy_of_a);
  }
};

/**
 * @brief Reduce-by-segment functor.
 *
 * Combines two cub::KeyValuePair inputs `a` and `b` using a binary associative
 * combining operator `f(const T &x, const T &y)`. The resulting pair has
 * `key == a.key + b.key`; its `value` is `b.value` when `b.key` is non-zero and
 * `f(a.value, b.value)` otherwise.
 *
 * ReduceBySegmentOp is an associative, non-commutative binary combining
 * operator for input sequences of cub::KeyValuePair pairings. Such sequences
 * typically represent a segmented set of values to be reduced together with a
 * corresponding set of {0,1}-valued integer "head flags" that mark the first
 * value of each segment.
 *
 * @tparam ReductionOpT Binary reduction operator to apply to values
 */
template <typename ReductionOpT>
struct ReduceBySegmentOp
{
  /// Binary operator used to combine values
  ReductionOpT op;

  /// Default constructor; `op` is default-initialized
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ReduceBySegmentOp() {}

  /// Stores a copy of @p reduction_op as the wrapped operator
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ReduceBySegmentOp(ReductionOpT reduction_op)
      : op(reduction_op)
  {}

  /**
   * @brief Combines two partial reductions
   *
   * @tparam KeyValuePairT
   *   KeyValuePair pairing of T (value) and OffsetT (head flag)
   *
   * @param[in] first
   *   Partial reduction on the left-hand side
   *
   * @param[in] second
   *   Partial reduction on the right-hand side
   *
   * @return The combined partial reduction
   */
  template <typename KeyValuePairT>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePairT operator()(const KeyValuePairT& first, const KeyValuePairT& second)
  {
    KeyValuePairT result;
    result.key = first.key + second.key;
#ifdef _NVHPC_CUDA // WAR bug on nvc++
    if (second.key)
    {
      result.value = second.value;
    }
    else
    {
      // nvc++ crashes while compiling the TestScanByKeyWithLargeTypes test in
      // thrust/testing/scan_by_key.cu unless second.value is first copied
      // into a temporary:
      auto second_value = second.value;
      result.value      = op(first.value, second_value);
    }
#else // not nvc++:
    // Non-zero second.key: the second partial reduction spans a segment reset,
    // so its value aggregate becomes the running aggregate.
    // Zero second.key: no reset is spanned, so both value aggregates are
    // accumulated into the running aggregate.
    result.value = (second.key) ? second.value : op(first.value, second.value);
#endif
    return result;
  }
};

/**
 * @brief Reduce-by-key functor.
 *
 * Combines two key-value pairs: the result is a copy of the second pair, whose
 * `value` is replaced by `op(first.value, second.value)` when both keys compare
 * equal.
 *
 * @tparam ReductionOpT Binary reduction operator to apply to values
 */
template <typename ReductionOpT>
struct ReduceByKeyOp
{
  /// Binary operator used to combine values
  ReductionOpT op;

  /// Default constructor; `op` is default-initialized
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ReduceByKeyOp() {}

  /// Stores a copy of @p reduction_op as the wrapped operator
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ReduceByKeyOp(ReductionOpT reduction_op)
      : op(reduction_op)
  {}

  /**
   * @brief Combines two partial reductions
   *
   * @param[in] first Partial reduction on the left-hand side
   * @param[in] second Partial reduction on the right-hand side
   *
   * @return The combined partial reduction; it always carries the key of @p second
   */
  template <typename KeyValuePairT>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePairT operator()(const KeyValuePairT& first, const KeyValuePairT& second)
  {
    // Start from the second pair; its value is only combined when both keys match
    KeyValuePairT combined = second;

    if (first.key == second.key)
    {
      combined.value = op(first.value, combined.value);
    }

    return combined;
  }
};

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document

//----------------------------------------------------------------------------------------------------------------------
// Predefined operators

namespace detail
{
//----------------------------------------------------------------------------------------------------------------------
// Type traits detecting the predefined operators

// Each detection trait below takes an operator type and an optional value type (defaulting to void) and is true for:
//   - <Functor<T>>     : typed operator, value type omitted
//   - <Functor<T>, T>  : typed operator paired with its own value type
//   - <Functor<>, T>   : transparent operator paired with any value type
//   - <Functor<>>      : transparent operator, value type omitted
// and false for every other argument combination.

/// Detects `::cuda::std::plus`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_std_plus_v = false;

template <typename ValueT>
inline constexpr bool is_cuda_std_plus_v<::cuda::std::plus<>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_std_plus_v<::cuda::std::plus<ValueT>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_std_plus_v<::cuda::std::plus<ValueT>, void> = true;

template <>
inline constexpr bool is_cuda_std_plus_v<::cuda::std::plus<>, void> = true;

/// Detects `::cuda::std::multiplies`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_std_mul_v = false;

template <typename ValueT>
inline constexpr bool is_cuda_std_mul_v<::cuda::std::multiplies<>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_std_mul_v<::cuda::std::multiplies<ValueT>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_std_mul_v<::cuda::std::multiplies<ValueT>, void> = true;

template <>
inline constexpr bool is_cuda_std_mul_v<::cuda::std::multiplies<>, void> = true;

/// Detects `::cuda::maximum`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_maximum_v = false;

template <typename ValueT>
inline constexpr bool is_cuda_maximum_v<::cuda::maximum<>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_maximum_v<::cuda::maximum<ValueT>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_maximum_v<::cuda::maximum<ValueT>, void> = true;

template <>
inline constexpr bool is_cuda_maximum_v<::cuda::maximum<>, void> = true;

/// Detects `::cuda::minimum`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_minimum_v = false;

template <typename ValueT>
inline constexpr bool is_cuda_minimum_v<::cuda::minimum<>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_minimum_v<::cuda::minimum<ValueT>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_minimum_v<::cuda::minimum<ValueT>, void> = true;

template <>
inline constexpr bool is_cuda_minimum_v<::cuda::minimum<>, void> = true;

// Each bitwise detection trait below takes an operator type and an optional value type (defaulting to void) and is
// true for <Functor<T>>, <Functor<T>, T>, <Functor<>, T> and <Functor<>>; it is false for everything else.

/// Detects `::cuda::std::bit_and`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_std_bit_and_v = false;

template <typename ValueT>
inline constexpr bool is_cuda_std_bit_and_v<::cuda::std::bit_and<>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_std_bit_and_v<::cuda::std::bit_and<ValueT>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_std_bit_and_v<::cuda::std::bit_and<ValueT>, void> = true;

template <>
inline constexpr bool is_cuda_std_bit_and_v<::cuda::std::bit_and<>, void> = true;

/// Detects `::cuda::std::bit_or`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_std_bit_or_v = false;

template <typename ValueT>
inline constexpr bool is_cuda_std_bit_or_v<::cuda::std::bit_or<>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_std_bit_or_v<::cuda::std::bit_or<ValueT>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_std_bit_or_v<::cuda::std::bit_or<ValueT>, void> = true;

template <>
inline constexpr bool is_cuda_std_bit_or_v<::cuda::std::bit_or<>, void> = true;

/// Detects `::cuda::std::bit_xor`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_std_bit_xor_v = false;

template <typename ValueT>
inline constexpr bool is_cuda_std_bit_xor_v<::cuda::std::bit_xor<>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_std_bit_xor_v<::cuda::std::bit_xor<ValueT>, ValueT> = true;

template <typename ValueT>
inline constexpr bool is_cuda_std_bit_xor_v<::cuda::std::bit_xor<ValueT>, void> = true;

template <>
inline constexpr bool is_cuda_std_bit_xor_v<::cuda::std::bit_xor<>, void> = true;

// The logical detection traits only recognize `bool`: they are true for <Functor<bool>>, <Functor<bool>, bool>,
// <Functor<>, bool> and <Functor<>>, and false for everything else.

/// Detects `::cuda::std::logical_and`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_std_logical_and_v = false;

template <>
inline constexpr bool is_cuda_std_logical_and_v<::cuda::std::logical_and<>, void> = true;

template <>
inline constexpr bool is_cuda_std_logical_and_v<::cuda::std::logical_and<>, bool> = true;

template <>
inline constexpr bool is_cuda_std_logical_and_v<::cuda::std::logical_and<bool>, void> = true;

template <>
inline constexpr bool is_cuda_std_logical_and_v<::cuda::std::logical_and<bool>, bool> = true;

/// Detects `::cuda::std::logical_or`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_std_logical_or_v = false;

template <>
inline constexpr bool is_cuda_std_logical_or_v<::cuda::std::logical_or<>, void> = true;

template <>
inline constexpr bool is_cuda_std_logical_or_v<::cuda::std::logical_or<>, bool> = true;

template <>
inline constexpr bool is_cuda_std_logical_or_v<::cuda::std::logical_or<bool>, void> = true;

template <>
inline constexpr bool is_cuda_std_logical_or_v<::cuda::std::logical_or<bool>, bool> = true;

/// True when `OperatorT` is detected as `::cuda::minimum` or `::cuda::maximum`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_minimum_maximum_v =
  is_cuda_minimum_v<OperatorT, ValueT> || is_cuda_maximum_v<OperatorT, ValueT>;

/// True when `OperatorT` is detected as `::cuda::std::plus` or `::cuda::std::multiplies`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_std_plus_mul_v =
  is_cuda_std_mul_v<OperatorT, ValueT> || is_cuda_std_plus_v<OperatorT, ValueT>;

/// True when `OperatorT` is detected as `::cuda::std::bit_and`, `::cuda::std::bit_or` or `::cuda::std::bit_xor`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_std_bitwise_v =
  is_cuda_std_bit_xor_v<OperatorT, ValueT> //
  || is_cuda_std_bit_or_v<OperatorT, ValueT> //
  || is_cuda_std_bit_and_v<OperatorT, ValueT>;

/// True when `OperatorT` is detected as `::cuda::std::logical_and` or `::cuda::std::logical_or`
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_std_logical_v =
  is_cuda_std_logical_or_v<OperatorT, ValueT> || is_cuda_std_logical_and_v<OperatorT, ValueT>;

/// Union of the minimum/maximum, plus/multiplies and bitwise operator families
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_simd_enabled_cuda_operator =
  is_cuda_std_bitwise_v<OperatorT, ValueT> //
  || is_cuda_std_plus_mul_v<OperatorT, ValueT> //
  || is_cuda_minimum_maximum_v<OperatorT, ValueT>;

/// Union of all detected operator families: minimum/maximum, plus/multiplies, bitwise and logical
template <typename OperatorT, typename ValueT = void>
inline constexpr bool is_cuda_binary_operator =
  is_simd_enabled_cuda_operator<OperatorT, ValueT> || is_cuda_std_logical_v<OperatorT, ValueT>;

//----------------------------------------------------------------------------------------------------------------------
// Generalize Operator

/**
 * @brief Maps an operator type to its default-argument ("generalized") form
 *
 * The primary template leaves the operator type unchanged.
 *
 * @tparam Operator
 *   Operator type to map
 */
template <typename Operator>
struct GeneralizeOperator
{
  using type = Operator;
};

/// Specialization for instantiations `Operator<T>` of a class template taking a single type parameter:
/// `type` names `Operator<>`, i.e. the instantiation that uses the template's default argument.
template <template <typename = void> class Operator, typename T>
struct GeneralizeOperator<Operator<T>>
{
  using type = Operator<>;
};

/// Shorthand for `GeneralizeOperator<Op>::type`
template <typename Op>
using generalize_operator_t = typename GeneralizeOperator<Op>::type;

/**
 * @brief Returns the generalized form of a predefined operator, or the operator itself otherwise
 *
 * @tparam Operator
 *   Type of the operator to generalize
 *
 * @param[in] op
 *   Operator instance; only used (returned by value) when `Operator` is not a predefined operator
 *
 * @return A value-initialized `generalize_operator_t<Operator>` when `is_cuda_binary_operator<Operator>` holds,
 *         otherwise a copy of @p op
 */
template <typename Operator>
[[nodiscard]] constexpr _CCCL_DEVICE _CCCL_FORCEINLINE auto generalize_operator([[maybe_unused]] Operator op)
{
  // is_cuda_binary_operator is the single definition of "predefined operator" (minimum/maximum, plus/multiplies,
  // bitwise and logical); relying on it keeps this function in sync with that trait.
  if constexpr (is_cuda_binary_operator<Operator>)
  {
    return generalize_operator_t<Operator>{};
  }
  else
  {
    return op;
  }
}

//----------------------------------------------------------------------------------------------------------------------
// Identity

/**
 * @brief Identity element of a predefined operator
 *
 * Specializations exist for three spellings of each supported operator:
 *   - `identity_v<Functor<>, T>`  : transparent operator paired with value type `T`
 *   - `identity_v<Functor<T>, T>` : typed operator paired with its own value type
 *   - `identity_v<Functor<T>>`    : typed operator, value type omitted
 *
 * The primary template has no initializer; only the specializations below provide a value, so requesting the
 * identity of an unsupported operator is expected to fail to compile.
 *
 * @tparam Op
 *   Operator type
 *
 * @tparam T
 *   Value type (optional for typed operators)
 */
template <typename Op, typename T = void>
inline constexpr T identity_v;

// minimum: the largest finite value of T

template <typename T>
inline constexpr T identity_v<::cuda::minimum<>, T> = ::cuda::std::numeric_limits<T>::max();

template <typename T>
inline constexpr T identity_v<::cuda::minimum<T>, T> = ::cuda::std::numeric_limits<T>::max();

template <typename T>
inline constexpr T identity_v<::cuda::minimum<T>, void> = ::cuda::std::numeric_limits<T>::max();

// maximum: the lowest finite value of T

template <typename T>
inline constexpr T identity_v<::cuda::maximum<>, T> = ::cuda::std::numeric_limits<T>::lowest();

template <typename T>
inline constexpr T identity_v<::cuda::maximum<T>, T> = ::cuda::std::numeric_limits<T>::lowest();

template <typename T>
inline constexpr T identity_v<::cuda::maximum<T>, void> = ::cuda::std::numeric_limits<T>::lowest();

// plus: value-initialized T (zero for arithmetic types)

template <typename T>
inline constexpr T identity_v<::cuda::std::plus<T>, T> = T{};

template <typename T>
inline constexpr T identity_v<::cuda::std::plus<>, T> = T{};

template <typename T>
inline constexpr T identity_v<::cuda::std::plus<T>, void> = T{};

// multiplies: one. multiplies is a recognized predefined operator (see is_cuda_std_mul_v), so it gets an identity
// like every other operator family.

template <typename T>
inline constexpr T identity_v<::cuda::std::multiplies<T>, T> = static_cast<T>(1);

template <typename T>
inline constexpr T identity_v<::cuda::std::multiplies<>, T> = static_cast<T>(1);

template <typename T>
inline constexpr T identity_v<::cuda::std::multiplies<T>, void> = static_cast<T>(1);

// bit_and: all bits set. The cast converts the result of `~` (which may have been promoted) back to T.

template <typename T>
inline constexpr T identity_v<::cuda::std::bit_and<>, T> = static_cast<T>(~T{});

template <typename T>
inline constexpr T identity_v<::cuda::std::bit_and<T>, T> = static_cast<T>(~T{});

template <typename T>
inline constexpr T identity_v<::cuda::std::bit_and<T>, void> = static_cast<T>(~T{});

// bit_or: value-initialized T (no bits set for integral types)

template <typename T>
inline constexpr T identity_v<::cuda::std::bit_or<>, T> = T{};

template <typename T>
inline constexpr T identity_v<::cuda::std::bit_or<T>, T> = T{};

template <typename T>
inline constexpr T identity_v<::cuda::std::bit_or<T>, void> = T{};

// bit_xor: value-initialized T (no bits set for integral types)

template <typename T>
inline constexpr T identity_v<::cuda::std::bit_xor<>, T> = T{};

template <typename T>
inline constexpr T identity_v<::cuda::std::bit_xor<T>, T> = T{};

template <typename T>
inline constexpr T identity_v<::cuda::std::bit_xor<T>, void> = T{};

// logical_and: true

template <typename T>
inline constexpr T identity_v<::cuda::std::logical_and<>, T> = true;

template <typename T>
inline constexpr T identity_v<::cuda::std::logical_and<T>, T> = true;

template <typename T>
inline constexpr T identity_v<::cuda::std::logical_and<T>, void> = true;

// logical_or: false

template <typename T>
inline constexpr T identity_v<::cuda::std::logical_or<>, T> = false;

template <typename T>
inline constexpr T identity_v<::cuda::std::logical_or<T>, T> = false;

template <typename T>
inline constexpr T identity_v<::cuda::std::logical_or<T>, void> = false;
} // namespace detail

#endif // !_CCCL_DOXYGEN_INVOKED

CUB_NAMESPACE_END
