#include <opencv2/opencv.hpp>
#include <cassert>
#include <numeric>
+#include <mutex>
+#include <stack>
#if defined(CUFFT) || defined(CUFFTW)
#include "cuda_runtime.h"
#endif
#endif
+class MemoryManager { // Pool of previously released buffers, grouped by a caller-supplied size key.
+ std::mutex mutex; // locked by both get() and put()
+ std::map<size_t, std::stack<void*> > map; // size key -> LIFO stack of pointers handed back through put()
+
+public:
+ void *get(size_t size) { // Returns a pooled pointer stored under `size`, or nullptr when none is available.
+ std::lock_guard<std::mutex> guard(mutex);
+ auto &stack = map[size]; // operator[] inserts an empty stack the first time a size key is seen
+ void *ptr = nullptr;
+ if (!stack.empty()) {
+ ptr = stack.top(); // most recently returned pointer for this size key
+ stack.pop();
+ }
+ return ptr; // nullptr means the pool had nothing stored under this size key
+ }
+ void put(void *ptr, size_t size) { // Stores `ptr` under `size` for a later get(); nothing in this class frees it.
+ std::lock_guard<std::mutex> guard(mutex);
+ map[size].push(ptr); // no check that `size` matches how `ptr` was allocated - the caller must keep keys consistent
+ }
+};
+
template <typename T> class DynMem_ { // Owns a host buffer of T (plus a device pointer when built with CUFFT).
private:
T *ptr_h = nullptr; // host-side pointer
#ifdef CUFFT
T *ptr_d = nullptr; // device pointer obtained from ptr_h via cudaHostGetDevicePointer
+ static MemoryManager mmng; // pool of released host buffers; one instance per DynMem_<T> instantiation
#endif
public:
typedef T value_type;
DynMem_(size_t num_elem) : num_elem(num_elem) // allocates (or reuses) storage for num_elem elements
{
#ifdef CUFFT
- CudaSafeCall(cudaHostAlloc(reinterpret_cast<void **>(&ptr_h), num_elem * sizeof(T), cudaHostAllocMapped));
+ ptr_h = reinterpret_cast<T*>(mmng.get(num_elem)); // pool key is the element count, not the byte size
+ if (!ptr_h) { // pool miss: fall back to a fresh allocation
+ printf("malloc(%zu)\n", num_elem); // NOTE(review): looks like a debug trace printed on every fresh allocation - confirm it is intended
+ CudaSafeCall(cudaHostAlloc(reinterpret_cast<void **>(&ptr_h), num_elem * sizeof(T), cudaHostAllocMapped));
+ }
CudaSafeCall(cudaHostGetDevicePointer(reinterpret_cast<void **>(&ptr_d), reinterpret_cast<void *>(ptr_h), 0)); // runs for both reused and freshly allocated buffers
#else
ptr_h = new T[num_elem];
void release()
{
#ifdef CUFFT
- CudaSafeCall(cudaFreeHost(ptr_h));
+ if (ptr_h)
+ mmng.put(ptr_h, num_elem); // hand the buffer back to the pool under the same element-count key used in the constructor
+ //CudaSafeCall(cudaFreeHost(ptr_h));
// NOTE(review): with cudaFreeHost commented out, nothing visible here ever frees pooled buffers - confirm this is acceptable
#else
delete[] ptr_h;
#endif
}
};
+#ifdef CUFFT
+template <typename T>
+MemoryManager DynMem_<T>::mmng; // out-of-class definition of the static pool declared in DynMem_ (CUFFT builds only)
+#endif
+
typedef DynMem_<float> DynMem; // shorthand for the float instantiation of DynMem_