#include <algorithm>
#include <cstddef>
#include <exception>
#include <functional>
#include <iterator>
#include <system_error>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

/// @brief Stable top-down merge sort that may sort the two halves of a range
///        concurrently for the first `max_depth` levels of recursion.
///
/// @tparam T Element type. Must be copy-constructible and move-assignable,
///           because elements are copied into a scratch buffer while merging
///           and moved back afterwards.
///
/// The comparator is invoked concurrently from several threads when
/// `max_depth > 0`, so it must be safe to call in parallel.
template <typename T>
class MergeSorterMT {
public:
    /// @brief Builds a sorter around a "comes before" predicate.
    /// @param cmp       Callable `bool(const T&, const T&)`; returns true when
    ///                  the first argument must be ordered before the second.
    ///                  It has to induce a strict weak ordering (like `<`).
    /// @param max_depth Number of recursion levels at which one half of the
    ///                  range is handed to a new thread. Zero or a negative
    ///                  value sorts entirely on the calling thread.
    template <typename C>
    MergeSorterMT(C cmp, int max_depth)
        : cmp(std::move(cmp)), max_depth(max_depth) {
        static_assert(
            std::is_same_v<std::invoke_result_t<C, const T &, const T &>, bool>,
            "C must be a function that returns a bool");
    }

    /// @brief Sorts `data` in place; equivalent elements keep their relative order.
    /// @throws Whatever the comparator throws. Worker threads are joined before
    ///         the exception is propagated to the caller.
    auto sort(std::vector<T> &data) -> void {
        split(data.begin(), data.end(), 0, max_depth);
    }

private:
    using Iter = typename std::vector<T>::iterator;

    /// Merges the adjacent sorted runs [first, mid) and [mid, last) into
    /// [first, last). The destination is only written once the merged sequence
    /// is complete, so a throwing comparator leaves the range untouched.
    auto merge(Iter first, Iter mid, Iter last) -> void {
        std::vector<T> buf;
        buf.reserve(static_cast<std::size_t>(std::distance(first, last)));

        Iter l = first;
        Iter r = mid;
        while (l != mid && r != last) {
            // Take from the right run only when it is strictly smaller; on a
            // tie the left element goes first, which keeps the sort stable.
            if (cmp(*r, *l)) {
                buf.push_back(*r);
                ++r;
            } else {
                buf.push_back(*l);
                ++l;
            }
        }
        // At most one of the two runs still has elements left.
        buf.insert(buf.end(), l, mid);
        buf.insert(buf.end(), r, last);

        std::move(buf.begin(), buf.end(), first);
    }

    /// Recursively sorts [first, last). While `depth < mdepth` the left half is
    /// sorted on a freshly started thread and the right half on the current one.
    auto split(Iter first, Iter last, int depth, int mdepth) -> void {
        const auto count = std::distance(first, last);
        if (count <= 1) {
            return;
        }
        if (count == 2) {
            // A pair needs at most one swap; never recurse (or spawn) for it.
            const Iter second = std::next(first);
            if (cmp(*second, *first)) {
                std::iter_swap(first, second);
            }
            return;
        }

        const Iter mid = std::next(first, count / 2);

        if (depth < mdepth) {
            std::exception_ptr left_error;
            std::thread left_thread;
            try {
                left_thread = std::thread([this, first, mid, depth, mdepth, &left_error] {
                    // An exception escaping a thread entry point would call
                    // std::terminate, so capture it and rethrow after join().
                    try {
                        split(first, mid, depth + 1, mdepth);
                    } catch (...) {
                        left_error = std::current_exception();
                    }
                });
            } catch (const std::system_error &) {
                // The system refused to start a thread; the left half is
                // sorted on the current thread below instead.
            }

            const bool spawned = left_thread.joinable();
            try {
                if (!spawned) {
                    split(first, mid, depth + 1, mdepth);
                }
                split(mid, last, depth + 1, mdepth);
            } catch (...) {
                // Destroying a joinable std::thread terminates the program.
                if (spawned) {
                    left_thread.join();
                }
                throw;
            }
            if (spawned) {
                left_thread.join();
            }
            if (left_error) {
                std::rethrow_exception(left_error);
            }
        } else {
            split(first, mid, depth + 1, mdepth);
            split(mid, last, depth + 1, mdepth);
        }

        merge(first, mid, last);
    }

private:
    std::function<bool(const T &, const T &)> cmp;
    const int max_depth;
};