Program Listing for File cufft.hpp

Return to documentation for file (include/cudawrappers/cufft.hpp)

#if !defined CUFFT_H
#define CUFFT_H

#include <cuda_fp16.h>
#include <cufft.h>
#include <cufftXt.h>

#include <array>
#include <exception>
#include <utility>

#include "cudawrappers/cu.hpp"

/*
 * Error handling helper function, copied from cuda-samples Common/helper_cuda.h
 */
// Translates a cufftResult status code into the spelling of its enumerator.
// Any code without a matching case below is reported as "<unknown>".
// The returned pointer refers to a string literal and must not be freed.
static const char *_cudaGetErrorEnum(cufftResult error) {
  const char *name = "<unknown>";

  switch (error) {
    case CUFFT_SUCCESS:
      name = "CUFFT_SUCCESS";
      break;
    case CUFFT_INVALID_PLAN:
      name = "CUFFT_INVALID_PLAN";
      break;
    case CUFFT_ALLOC_FAILED:
      name = "CUFFT_ALLOC_FAILED";
      break;
    case CUFFT_INVALID_TYPE:
      name = "CUFFT_INVALID_TYPE";
      break;
    case CUFFT_INVALID_VALUE:
      name = "CUFFT_INVALID_VALUE";
      break;
    case CUFFT_INTERNAL_ERROR:
      name = "CUFFT_INTERNAL_ERROR";
      break;
    case CUFFT_EXEC_FAILED:
      name = "CUFFT_EXEC_FAILED";
      break;
    case CUFFT_SETUP_FAILED:
      name = "CUFFT_SETUP_FAILED";
      break;
    case CUFFT_INVALID_SIZE:
      name = "CUFFT_INVALID_SIZE";
      break;
    case CUFFT_UNALIGNED_DATA:
      name = "CUFFT_UNALIGNED_DATA";
      break;
    case CUFFT_INCOMPLETE_PARAMETER_LIST:
      name = "CUFFT_INCOMPLETE_PARAMETER_LIST";
      break;
    case CUFFT_INVALID_DEVICE:
      name = "CUFFT_INVALID_DEVICE";
      break;
    case CUFFT_PARSE_ERROR:
      name = "CUFFT_PARSE_ERROR";
      break;
    case CUFFT_NO_WORKSPACE:
      name = "CUFFT_NO_WORKSPACE";
      break;
    case CUFFT_NOT_IMPLEMENTED:
      name = "CUFFT_NOT_IMPLEMENTED";
      break;
    case CUFFT_LICENSE_ERROR:
      name = "CUFFT_LICENSE_ERROR";
      break;
    case CUFFT_NOT_SUPPORTED:
      name = "CUFFT_NOT_SUPPORTED";
      break;
  }

  return name;
}

namespace cufft {

/*
 * Error
 */
class Error : public std::exception {
 public:
  explicit Error(cufftResult result) : result_(result) {}

  const char *what() const noexcept override {
    return _cudaGetErrorEnum(result_);
  }

  operator cufftResult() const { return result_; }

 private:
  cufftResult result_;
};

/*
 * FFT
 */
class FFT {
 public:
  FFT() = default;
  FFT &operator=(FFT &) = delete;
  FFT(FFT &) = delete;
  FFT &operator=(FFT &&other) noexcept {
    if (&other != this) {
      plan_ = other.plan_;
      other.plan_ = 0;
    }
    return *this;
  }
  FFT(FFT &&other) noexcept { *this = std::move(other); }

  ~FFT() { checkCuFFTCall(cufftDestroy(plan_)); }

  void setStream(cu::Stream &stream) {
    checkCuFFTCall(cufftSetStream(plan_, stream));
  }

  void execute(cu::DeviceMemory &in, cu::DeviceMemory &out, int direction) {
    void *in_ptr = reinterpret_cast<void *>(static_cast<CUdeviceptr>(in));
    void *out_ptr = reinterpret_cast<void *>(static_cast<CUdeviceptr>(out));
    checkCuFFTCall(cufftXtExec(plan_, in_ptr, out_ptr, direction));
  }

 protected:
  void checkCuFFTCall(cufftResult result) {
    if (result != CUFFT_SUCCESS) {
      throw Error(result);
    }
  }

  cufftHandle *plan() { return &plan_; }

 private:
  cufftHandle plan_{};
};

/*
 * FFT1D
 */
// Primary template for 1D transforms, selected by the element data type.
// Both constructors are deleted here, so a data type is only usable when
// explicit specializations of these constructors are provided for it.
template <cudaDataType_t DataType>
class FFT1D : public FFT {
 public:
  // Batched form: `batch` transforms of length `nx`.
  FFT1D(int nx, int batch) = delete;
  // Single transform of length `nx`.
  FFT1D(int nx) = delete;
};

/*
 * Single-precision complex-to-complex 1D plan: `batch` transforms of
 * length `nx`. Throws cufft::Error if cuFFT reports a failure.
 *
 * Declared inline because an explicit specialization of a member function
 * is an ordinary function: without inline, including this header from more
 * than one translation unit produces duplicate definitions.
 */
template <>
inline FFT1D<CUDA_C_32F>::FFT1D(int nx, int batch) {
  checkCuFFTCall(cufftCreate(plan()));
  // Configure the handle that cufftCreate() just allocated. cufftPlan1d()
  // would allocate a second handle and overwrite this one, leaking it.
  size_t work_size = 0;  // required output argument; the value is not used
  checkCuFFTCall(cufftMakePlan1d(*plan(), nx, CUFFT_C2C, batch, &work_size));
}

/*
 * Single-precision 1D plan for one transform of length `nx`; delegates to
 * the batched constructor with a batch count of 1.
 *
 * Declared inline so that this header can be included from more than one
 * translation unit without duplicate definitions (an explicit specialization
 * of a member function is an ordinary function).
 */
template <>
inline FFT1D<CUDA_C_32F>::FFT1D(int nx) : FFT1D(nx, 1) {}

/*
 * Half-precision complex-to-complex 1D plan: `batch` transforms of length
 * `nx`, with input, output and execution type all CUDA_C_16F.
 * Throws cufft::Error if cuFFT reports a failure.
 *
 * Declared inline so that this header can be included from more than one
 * translation unit without duplicate definitions (an explicit specialization
 * of a member function is an ordinary function).
 */
template <>
inline FFT1D<CUDA_C_16F>::FFT1D(int nx, int batch) {
  checkCuFFTCall(cufftCreate(plan()));
  const int rank = 1;
  size_t ws = 0;  // required output argument; the value is not used
  std::array<long long, 1> n{nx};
  // inembed/onembed are passed as nullptr below. Per the cuFFT documentation
  // this makes cuFFT ignore the stride/dist arguments and use its default
  // layout, so the values here are placeholders only.
  long long int idist = 1;
  long long int odist = 1;
  int istride = 1;
  int ostride = 1;
  checkCuFFTCall(cufftXtMakePlanMany(*plan(), rank, n.data(), nullptr, istride,
                                     idist, CUDA_C_16F, nullptr, ostride, odist,
                                     CUDA_C_16F, batch, &ws, CUDA_C_16F));
}

/*
 * Half-precision 1D plan for one transform of length `nx`; delegates to the
 * batched constructor with a batch count of 1.
 *
 * Declared inline so that this header can be included from more than one
 * translation unit without duplicate definitions (an explicit specialization
 * of a member function is an ordinary function).
 */
template <>
inline FFT1D<CUDA_C_16F>::FFT1D(int nx) : FFT1D(nx, 1) {}

/*
 * FFT2D
 */
// Primary template for 2D transforms, selected by the element data type.
// Both constructors are deleted here, so a data type is only usable when
// explicit specializations of these constructors are provided for it.
template <cudaDataType_t DataType>
class FFT2D : public FFT {
 public:
  // Batched form with an explicit element stride and per-batch distance.
  FFT2D(int nx, int ny, int stride, int dist, int batch) = delete;
  // Single transform of size nx by ny.
  FFT2D(int nx, int ny) = delete;
};

/*
 * Single-precision complex-to-complex 2D plan of size `nx` by `ny`.
 * Throws cufft::Error if cuFFT reports a failure.
 *
 * Declared inline so that this header can be included from more than one
 * translation unit without duplicate definitions (an explicit specialization
 * of a member function is an ordinary function).
 */
template <>
inline FFT2D<CUDA_C_32F>::FFT2D(int nx, int ny) {
  checkCuFFTCall(cufftCreate(plan()));
  // Configure the handle that cufftCreate() just allocated. cufftPlan2d()
  // would allocate a second handle and overwrite this one, leaking it.
  size_t work_size = 0;  // required output argument; the value is not used
  checkCuFFTCall(cufftMakePlan2d(*plan(), nx, ny, CUFFT_C2C, &work_size));
}

/*
 * Single-precision complex-to-complex batched 2D plan.
 *
 * nx, ny : transform dimensions, passed to cuFFT as n[0] and n[1]
 * stride : element stride, used for both input and output
 * dist   : distance between consecutive batches, used for both input and
 *          output
 * batch  : number of transforms
 *
 * Throws cufft::Error if cuFFT reports a failure.
 *
 * Declared inline so that this header can be included from more than one
 * translation unit without duplicate definitions (an explicit specialization
 * of a member function is an ordinary function).
 */
template <>
inline FFT2D<CUDA_C_32F>::FFT2D(int nx, int ny, int stride, int dist,
                                int batch) {
  checkCuFFTCall(cufftCreate(plan()));
  std::array<int, 2> n{nx, ny};
  size_t work_size = 0;  // required output argument; the value is not used
  // Configure the handle that cufftCreate() just allocated. cufftPlanMany()
  // would allocate a second handle and overwrite this one, leaking it.
  // The transform size doubles as the storage dimensions (inembed/onembed).
  checkCuFFTCall(cufftMakePlanMany(*plan(), 2, n.data(), n.data(), stride,
                                   dist, n.data(), stride, dist, CUFFT_C2C,
                                   batch, &work_size));
}

/*
 * Half-precision complex-to-complex batched 2D plan, with input, output and
 * execution type all CUDA_C_16F.
 *
 * nx, ny : transform dimensions, passed to cuFFT as n[0] and n[1]
 * stride : element stride, used for both input and output
 * dist   : distance between consecutive batches, used for both input and
 *          output
 * batch  : number of transforms
 *
 * Throws cufft::Error if cuFFT reports a failure.
 *
 * Declared inline so that this header can be included from more than one
 * translation unit without duplicate definitions (an explicit specialization
 * of a member function is an ordinary function).
 */
template <>
inline FFT2D<CUDA_C_16F>::FFT2D(int nx, int ny, int stride, int dist,
                                int batch) {
  checkCuFFTCall(cufftCreate(plan()));
  const int rank = 2;
  size_t ws = 0;  // required output argument; the value is not used
  std::array<long long, 2> n{nx, ny};
  int istride = stride;
  int ostride = stride;
  long long int idist = dist;
  long long int odist = dist;
  // The transform size doubles as the storage dimensions (inembed/onembed),
  // as in the CUDA_C_32F overload. Per the cuFFT documentation, passing
  // nullptr for them would make cuFFT ignore the caller's stride and dist.
  checkCuFFTCall(cufftXtMakePlanMany(*plan(), rank, n.data(), n.data(),
                                     istride, idist, CUDA_C_16F, n.data(),
                                     ostride, odist, CUDA_C_16F, batch, &ws,
                                     CUDA_C_16F));
}

/*
 * Half-precision 2D plan for one transform of size `nx` by `ny`; delegates
 * to the batched constructor with stride 1, dist nx * ny and batch 1.
 *
 * Declared inline so that this header can be included from more than one
 * translation unit without duplicate definitions (an explicit specialization
 * of a member function is an ordinary function).
 */
template <>
inline FFT2D<CUDA_C_16F>::FFT2D(int nx, int ny)
    : FFT2D(nx, ny, 1, nx * ny, 1) {}

}  // namespace cufft

#endif  // CUFFT_H