88 lines
2.3 KiB
C++
88 lines
2.3 KiB
C++
#include <algorithm>
#include <functional>
#include <future>
#include <iterator>
#include <mutex>
#include <span>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
|
|
|
|
/// Stable merge sort over a std::vector<T> that sorts independent halves
/// concurrently for the top `max_depth` levels of recursion.
///
/// At each level shallower than `max_depth`, the left half is sorted by an
/// asynchronous task (std::launch::async) while the current thread sorts the
/// right half. A sort therefore starts up to 2^max_depth - 1 extra tasks.
/// A `max_depth` <= 0 sorts entirely on the calling thread.
///
/// The comparator must behave like `operator<` (a strict weak ordering).
/// Elements that compare equivalent keep their original relative order.
template<typename T>
class MergeSorterMT {
public:
    /// @param cmp        comparator; `cmp(a, b)` returns true when `a` must be
    ///                   placed before `b`. It is called with `const T&`.
    /// @param max_depth  recursion depth down to which halves are sorted in
    ///                   parallel.
    template<typename C>
    MergeSorterMT(C cmp, int max_depth) : cmp(std::move(cmp)), max_depth(max_depth) {
        // Elements are passed as const lvalues. Checking that here gives a
        // readable diagnostic instead of a deep std::function error.
        static_assert(std::is_invocable_r_v<bool, C &, const T &, const T &>,
                      "C must be a function that returns a bool");
    }

    /// Sorts `data` in place.
    /// If the comparator (or T's copy/move) throws, the exception reaches the
    /// caller only after every started task has finished. In that case `data`
    /// is left in a valid but unspecified state.
    auto sort(std::vector<T> &data) -> void {
        split(data.begin(), data.end(), 0);
    }

private:
    using Iter = typename std::vector<T>::iterator;

    /// Merges the adjacent sorted runs [first, mid) and [mid, last) in place.
    auto merge(Iter first, Iter mid, Iter last) -> void {
        std::vector<T> buf;
        buf.reserve(static_cast<typename std::vector<T>::size_type>(std::distance(first, last)));
        // std::merge takes from the left run when elements are equivalent,
        // which keeps the sort stable. std::cref avoids copying the
        // std::function into the algorithm.
        std::merge(first, mid, mid, last, std::back_inserter(buf), std::cref(cmp));
        std::move(buf.begin(), buf.end(), first);
    }

    /// Recursively sorts [first, last). `depth` is the current recursion level.
    auto split(Iter first, Iter last, int depth) -> void {
        const auto len = std::distance(first, last);
        if (len < 2) {
            return;
        }
        if (len == 2) {
            if (cmp(first[1], first[0])) {
                std::iter_swap(first, first + 1);
            }
            return;  // ordered either way; no merge needed
        }

        const Iter mid = first + len / 2;
        if (depth < max_depth) {
            // Left half runs as an async task, right half on this thread.
            // The future returned by std::async waits for the task in its
            // destructor, so the task cannot outlive this frame even if the
            // right half throws. get() re-throws anything raised by the task.
            auto left = std::async(std::launch::async,
                                   [this, first, mid, depth] { split(first, mid, depth + 1); });
            split(mid, last, depth + 1);
            left.get();
        } else {
            split(first, mid, depth + 1);
            split(mid, last, depth + 1);
        }

        merge(first, mid, last);
    }

    // Takes const T& so that a comparison does not copy its operands.
    std::function<bool(const T &, const T &)> cmp;
    const int max_depth;
};