Program Listing for File cufft.hpp
↰ Return to documentation for file (include/cudawrappers/cufft.hpp)
#if !defined CUFFT_H
#define CUFFT_H
#if defined(__HIP__)
#include <hip/hip_fp16.h>
#include <hipfft/hipfft.h>
#include <hipfft/hipfftXt.h>
#else
#include <cuda_fp16.h>
#include <cufft.h>
#include <cufftXt.h>
#endif
#include <array>
#include <exception>
#include <string>
#include <utility>

#include <magic_enum/magic_enum.hpp>

#include "cudawrappers/cu.hpp"
/*
 * Error handling helper function
 *
 * Converts a cuFFT status code into the name of its enumerator, as reported
 * by magic_enum::enum_name(). magic_enum returns an empty name for a value it
 * cannot map to an enumerator; in that case a text containing the numeric
 * value is returned instead, so that the resulting message is never empty.
 */
static std::string _cudaGetErrorEnum(cufftResult_t error) {
  const auto name = magic_enum::enum_name(error);
  if (name.empty()) {
    // Unmapped value: keep at least the raw code in the message.
    return "unknown cufftResult (" + std::to_string(static_cast<int>(error)) +
           ")";
  }
  return std::string(name);
}
namespace cufft {
/*
* Error
*/
class Error : public std::exception {
public:
explicit Error(cufftResult result) : result_(result) {}
const char *what() const noexcept override {
message_ = _cudaGetErrorEnum(result_);
return message_.c_str();
}
operator cufftResult() const { return result_; }
private:
cufftResult result_;
mutable std::string message_ = "";
};
/*
* FFT
*/
class FFT {
public:
FFT() = default;
FFT &operator=(FFT &) = delete;
FFT(FFT &) = delete;
FFT &operator=(FFT &&other) noexcept {
if (&other != this) {
plan_ = other.plan_;
other.plan_ = 0;
}
return *this;
}
FFT(FFT &&other) noexcept { *this = std::move(other); }
~FFT() { checkCuFFTCall(cufftDestroy(plan_)); }
void setStream(cu::Stream &stream) const {
checkCuFFTCall(cufftSetStream(plan_, stream));
}
void execute(cu::DeviceMemory &in, cu::DeviceMemory &out,
const int direction) const {
void *in_ptr = reinterpret_cast<void *>(static_cast<CUdeviceptr>(in));
void *out_ptr = reinterpret_cast<void *>(static_cast<CUdeviceptr>(out));
checkCuFFTCall(cufftXtExec(plan_, in_ptr, out_ptr, direction));
}
protected:
void checkCuFFTCall(cufftResult result) const {
if (result != CUFFT_SUCCESS) {
throw Error(result);
}
}
cufftHandle *plan() { return &plan_; }
private:
cufftHandle plan_{};
};
/*
 * FFT1D
 *
 * One-dimensional FFT plan, selected by the data type T. Both constructors
 * are deleted in the generic template, so only data types for which an
 * explicit constructor specialization exists can be instantiated.
 */
template <cudaDataType_t T>
class FFT1D : public FFT {
 public:
// Plan for a transform of size nx.
#ifdef __HIP__
  __host__
#endif
      FFT1D(const int nx) = delete;

// Plan for `batch` transforms of size nx.
#ifdef __HIP__
  __host__
#endif
      FFT1D(const int nx, const int batch) = delete;
};
/*
 * FFT1D, single-precision complex-to-complex: `batch` transforms of size nx.
 * Throws Error if plan creation fails.
 */
template <>
inline FFT1D<CUDA_C_32F>::FFT1D(const int nx, const int batch) {
  // cufftPlan1d() allocates the plan handle itself and stores it through the
  // pointer it receives. Calling cufftCreate() first would allocate a second
  // handle that is overwritten here and never destroyed.
  checkCuFFTCall(cufftPlan1d(plan(), nx, CUFFT_C2C, batch));
}

/*
 * FFT1D, single-precision complex-to-complex: one transform of size nx.
 */
template <>
inline FFT1D<CUDA_C_32F>::FFT1D(const int nx) : FFT1D(nx, 1) {}
/*
 * FFT1D, half-precision complex-to-complex: `batch` transforms of size nx.
 * Input, output and execution type are all CUDA_C_16F.
 * Throws Error if handle or plan creation fails.
 */
template <>
inline FFT1D<CUDA_C_16F>::FFT1D(const int nx, const int batch) {
  checkCuFFTCall(cufftCreate(plan()));

  constexpr int kRank = 1;
  std::array<long long, kRank> size{nx};
  size_t work_size = 0;  // filled in by cuFFT, not used afterwards

  // No embedding arrays are supplied; stride and distance are passed as 1.
  checkCuFFTCall(cufftXtMakePlanMany(
      *plan(), kRank, size.data(),
      /*inembed=*/nullptr, /*istride=*/1, /*idist=*/1, CUDA_C_16F,
      /*onembed=*/nullptr, /*ostride=*/1, /*odist=*/1, CUDA_C_16F, batch,
      &work_size, CUDA_C_16F));
}

/*
 * FFT1D, half-precision complex-to-complex: one transform of size nx.
 */
template <>
inline FFT1D<CUDA_C_16F>::FFT1D(const int nx) : FFT1D(nx, 1) {}
/*
 * FFT2D
 *
 * Two-dimensional FFT plan, selected by the data type T. Both constructors
 * are deleted in the generic template, so only data types for which an
 * explicit constructor specialization exists can be instantiated.
 */
template <cudaDataType_t T>
class FFT2D : public FFT {
 public:
// Plan for a transform of size nx by ny.
#ifdef __HIP__
  __host__
#endif
      FFT2D(const int nx, const int ny) = delete;

// Plan for `batch` transforms of size nx by ny, with an explicit element
// stride and an explicit distance between consecutive transforms.
#ifdef __HIP__
  __host__
#endif
      FFT2D(const int nx, const int ny, const int stride, const int dist,
            const int batch) = delete;
};
/*
 * FFT2D, single-precision complex-to-complex: one transform of size nx by ny.
 * Throws Error if plan creation fails.
 */
template <>
inline FFT2D<CUDA_C_32F>::FFT2D(const int nx, const int ny) {
  // cufftPlan2d() allocates the plan handle itself and stores it through the
  // pointer it receives. Calling cufftCreate() first would allocate a second
  // handle that is overwritten here and never destroyed.
  checkCuFFTCall(cufftPlan2d(plan(), nx, ny, CUFFT_C2C));
}

/*
 * FFT2D, single-precision complex-to-complex: `batch` transforms of size nx
 * by ny. `stride` and `dist` are used for both the input and the output
 * layout, and the transform size doubles as the embedding size.
 * Throws Error if plan creation fails.
 */
template <>
inline FFT2D<CUDA_C_32F>::FFT2D(const int nx, const int ny, const int stride,
                                const int dist, const int batch) {
  std::array<int, 2> n{nx, ny};
  // As with cufftPlan2d(), cufftPlanMany() allocates the handle on its own,
  // so no cufftCreate() call precedes it.
  checkCuFFTCall(cufftPlanMany(plan(), 2, n.data(), n.data(), stride, dist,
                               n.data(), stride, dist, CUFFT_C2C, batch));
}
/*
 * FFT2D, half-precision complex-to-complex: `batch` transforms of size nx by
 * ny. `stride` and `dist` are used for both the input and the output layout.
 * Input, output and execution type are all CUDA_C_16F.
 * Throws Error if handle or plan creation fails.
 */
template <>
inline FFT2D<CUDA_C_16F>::FFT2D(const int nx, const int ny, const int stride,
                                const int dist, const int batch) {
  checkCuFFTCall(cufftCreate(plan()));

  const int rank = 2;
  size_t ws = 0;  // filled in by cuFFT, not used afterwards
  std::array<long long, 2> n{nx, ny};
  // cuFFT ignores the stride and distance arguments when the embedding
  // pointers are null, which would silently drop `stride` and `dist`. The
  // transform size is therefore passed as the embedding size, in the same way
  // as the CUDA_C_32F specialization does.
  std::array<long long, 2> embed{nx, ny};
  const long long istride = stride;
  const long long ostride = stride;
  const long long idist = dist;
  const long long odist = dist;
  checkCuFFTCall(cufftXtMakePlanMany(
      *plan(), rank, n.data(), embed.data(), istride, idist, CUDA_C_16F,
      embed.data(), ostride, odist, CUDA_C_16F, batch, &ws, CUDA_C_16F));
}

/*
 * FFT2D, half-precision complex-to-complex: one contiguous transform of size
 * nx by ny (stride 1, distance nx * ny, batch 1).
 */
template <>
inline FFT2D<CUDA_C_16F>::FFT2D(const int nx, const int ny)
    : FFT2D(nx, ny, 1, nx * ny, 1) {}
/*
 * FFT1DR2C
 *
 * One-dimensional real-to-complex FFT plan, selected by the input data type
 * T. All constructors are deleted in the generic template, so only
 * constructors with an explicit specialization can be used.
 */
template <cudaDataType_t T>
class FFT1DR2C : public FFT {
 public:
// Plan for a transform of size nx.
#ifdef __HIP__
  __host__
#endif
      FFT1DR2C(const int nx) = delete;

// Plan for `batch` transforms of size nx.
#ifdef __HIP__
  __host__
#endif
      FFT1DR2C(const int nx, const int batch) = delete;

// Plan for `batch` transforms of size nx with explicit input and output
// embedding sizes.
#ifdef __HIP__
  __host__
#endif
      FFT1DR2C(const int nx, const int batch, long long inembed,
               long long ouembed) = delete;
};
/*
 * FFT1DR2C, single precision: `batch` transforms of size nx from CUDA_R_32F
 * input to CUDA_C_32F output, executed as CUDA_C_32F. `inembed` and `ouembed`
 * serve both as the embedding size and as the distance between consecutive
 * transforms on the input and output side respectively; the stride is 1.
 * Throws Error if handle or plan creation fails.
 */
template <>
inline FFT1DR2C<CUDA_R_32F>::FFT1DR2C(const int nx, const int batch,
                                      long long inembed, long long ouembed) {
  checkCuFFTCall(cufftCreate(plan()));

  constexpr int kRank = 1;
  constexpr long long kUnitStride = 1;
  std::array<long long, kRank> size{nx};
  size_t work_size = 0;  // filled in by cuFFT, not used afterwards

  checkCuFFTCall(cufftXtMakePlanMany(
      *plan(), kRank, size.data(),
      &inembed, kUnitStride, /*idist=*/inembed, CUDA_R_32F,
      &ouembed, kUnitStride, /*odist=*/ouembed, CUDA_C_32F, batch, &work_size,
      CUDA_C_32F));
}
/*
 * FFT1DC2R
 *
 * One-dimensional complex-to-real FFT plan, selected by the input data type
 * T. All constructors are deleted in the generic template, so only
 * constructors with an explicit specialization can be used.
 */
template <cudaDataType_t T>
class FFT1DC2R : public FFT {
 public:
// Plan for a transform of size nx.
#ifdef __HIP__
  __host__
#endif
      FFT1DC2R(const int nx) = delete;

// Plan for `batch` transforms of size nx.
#ifdef __HIP__
  __host__
#endif
      FFT1DC2R(const int nx, const int batch) = delete;

// Plan for `batch` transforms of size nx with explicit input and output
// embedding sizes.
#ifdef __HIP__
  __host__
#endif
      FFT1DC2R(const int nx, const int batch, long long inembed,
               long long ouembed) = delete;
};
/*
 * FFT1DC2R, single precision: `batch` transforms of size nx from CUDA_C_32F
 * input to CUDA_R_32F output, executed as CUDA_C_32F. `inembed` and `ouembed`
 * serve both as the embedding size and as the distance between consecutive
 * transforms on the input and output side respectively; the stride is 1.
 * Throws Error if handle or plan creation fails.
 */
template <>
inline FFT1DC2R<CUDA_C_32F>::FFT1DC2R(const int nx, const int batch,
                                      long long inembed, long long ouembed) {
  checkCuFFTCall(cufftCreate(plan()));

  constexpr int kRank = 1;
  constexpr int kUnitStride = 1;
  std::array<long long, kRank> size{nx};
  size_t work_size = 0;  // filled in by cuFFT, not used afterwards

  checkCuFFTCall(cufftXtMakePlanMany(
      *plan(), kRank, size.data(),
      &inembed, kUnitStride, /*idist=*/inembed, CUDA_C_32F,
      &ouembed, kUnitStride, /*odist=*/ouembed, CUDA_R_32F, batch, &work_size,
      CUDA_C_32F));
}
} // namespace cufft
#endif // CUFFT_H