#include <algorithm>
#include <functional>
#include <future>
#include <iterator>
#include <type_traits>
#include <utility>
#include <vector>

/// Stable merge sort that sorts the two halves of each range concurrently
/// until a configurable recursion depth is reached, then continues
/// sequentially.
///
/// @tparam T element type of the std::vector being sorted.
template <typename T>
class MergeSorterMT {
public:
    /// @param cmp       ordering predicate: cmp(a, b) is true when a must be
    ///                  placed before b. It is called from several threads at
    ///                  once (see split), so calling it concurrently must be safe.
    /// @param max_depth recursion depths below this value sort their two halves
    ///                  in parallel; a value <= 0 sorts entirely on the calling
    ///                  thread.
    template <typename C>
    MergeSorterMT(C cmp, int max_depth) : cmp(std::move(cmp)), max_depth(max_depth) {
        static_assert(std::is_invocable_r_v<bool, C &, const T &, const T &>,
                      "C must be a function that returns a bool");
    }

    /// Sorts @p data in place. Elements that compare equal keep their relative
    /// order. An exception thrown by the comparator (on any thread) propagates
    /// to the caller; @p data is then left in an unspecified order.
    auto sort(std::vector<T> &data) -> void {
        split(data.begin(), data.end(), 0);
    }

private:
    using Iter = typename std::vector<T>::iterator;

    /// Merges the adjacent sorted ranges [first, mid) and [mid, last) back into
    /// [first, last).
    auto merge(Iter first, Iter mid, Iter last) -> void {
        std::vector<T> buf;
        buf.reserve(static_cast<typename std::vector<T>::size_type>(std::distance(first, last)));
        // std::merge takes from the right range only when its element is
        // strictly smaller, so ties keep the left element first (stable sort).
        // Move iterators avoid copying T; the comparator only reads through
        // const references, so it never consumes an element.
        std::merge(std::make_move_iterator(first), std::make_move_iterator(mid),
                   std::make_move_iterator(mid), std::make_move_iterator(last),
                   std::back_inserter(buf),
                   [this](const T &a, const T &b) { return cmp(a, b); });
        // No lock is required: concurrently running merges always touch
        // disjoint ranges, and a parent merges only after its children finished.
        std::move(buf.begin(), buf.end(), first);
    }

    /// Recursively sorts [first, last). While @p depth < max_depth the left half
    /// is sorted on a worker thread while the current thread sorts the right half.
    auto split(Iter first, Iter last, int depth) -> void {
        const auto size = std::distance(first, last);
        if (size <= 1) {
            return;
        }
        if (size == 2) {
            // Swap only on a strict inversion so equal elements stay in order.
            if (cmp(first[1], first[0])) {
                std::iter_swap(first, first + 1);
            }
            return;
        }

        const Iter mid = first + size / 2;
        if (depth < max_depth) {
            // The future returned by std::async waits for its task on
            // destruction (including during stack unwinding), and get()
            // rethrows any exception raised by the comparator in the worker.
            auto left_done = std::async(std::launch::async,
                                        [this, first, mid, depth] { split(first, mid, depth + 1); });
            split(mid, last, depth + 1);
            left_done.get();
        } else {
            split(first, mid, depth + 1);
            split(mid, last, depth + 1);
        }
        merge(first, mid, last);
    }

    std::function<bool(const T &, const T &)> cmp;
    const int max_depth;
};