fix errors

This commit is contained in:
Robin Dietzel 2023-11-06 21:13:13 +01:00
parent b1e7903bbd
commit 510d1e4e4b
2 changed files with 14 additions and 9 deletions

View File

@ -67,16 +67,16 @@ auto main(int argc, char *argv[]) -> int {
//const int max_depth = std::thread::hardware_concurrency(); //const int max_depth = std::thread::hardware_concurrency();
const int max_depth = 4; const int max_depth = 4;
t1 = std::chrono::high_resolution_clock::now(); t1 = std::chrono::high_resolution_clock::now();
MergeSorterMT<int, bool(*)(int, int)> ms([](int a, int b) { MergeSorterMT<int> ms([](int a, int b) {
return (a>b); return (a>b);
}); });
std::span t = dataset_par; std::span t = dataset_par;
int mdepth = 4; int mdepth = 4;
std::recursive_mutex mut; std::recursive_mutex mut;
ms.split(t, 0, mdepth, mut); ms.split(t, 0, mdepth, mut);
algo::MergeSort_mt::sort(dataset_par, [](int32_t a, int32_t b) { // algo::MergeSort_mt::sort(dataset_par, [](int32_t a, int32_t b) {
return (a > b); // return (a > b);
}, max_depth); // }, max_depth);
t2 = std::chrono::high_resolution_clock::now(); t2 = std::chrono::high_resolution_clock::now();
delay_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1); delay_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1);

View File

@ -3,14 +3,17 @@
#include <thread> #include <thread>
#include <mutex> #include <mutex>
template<typename T, typename C> template<typename T>
class MergeSorterMT { class MergeSorterMT {
public: public:
MergeSorterMT(C cmp) : cmp(cmp){} template<typename C>
MergeSorterMT(C cmp) : cmp(cmp){
static_assert(std::is_same<std::invoke_result_t<C, T, T>, bool>(), "C must be a function that returns a bool");
}
C cmp; std::function<bool(T, T)> cmp;
std::recursive_mutex mut; std::recursive_mutex mut;
auto merge(std::span<T> &output, std::span<T> left, std::span<T> right) -> void { auto merge(std::span<T> &output, std::span<T> left, std::span<T> right) -> void {
@ -60,8 +63,10 @@ public:
std::span<T> right(mid, data.end()); std::span<T> right(mid, data.end());
if(depth < max_depth) { if(depth < max_depth) {
std::thread left_thread(&MergeSorterMT<T, C>::split, this, left, depth + 1, max_depth, mut); //std::thread left_thread(&MergeSorterMT::split, this, left, depth + 1, max_depth, mut);
std::thread right_thread(&MergeSorterMT<T, C>::split, this, right, depth + 1, max_depth, mut); //std::thread right_thread(&MergeSorterMT::split, this, right, depth + 1, max_depth, mut);
std::thread left_thread([&](){split(left, depth + 1, max_depth, mut);});
std::thread right_thread([&](){split(right, depth + 1, max_depth, mut);});
left_thread.join(); left_thread.join();
right_thread.join(); right_thread.join();