// aca-tasks/task1/mergesort_mt.h
//
// 90 lines
// 2.5 KiB
// C++

#include <algorithm>
#include <functional>
#include <future>
#include <iterator>
#include <mutex>
#include <span>
#include <system_error>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
/// Parallel top-down merge sort for `std::vector<T>`.
///
/// The range is halved recursively. On the first `max_depth` recursion levels
/// the left half is sorted by an asynchronous task while the calling thread
/// sorts the right half, so up to 2^max_depth threads (including the caller)
/// sort at the same time; deeper levels run sequentially.
///
/// No locking is required: tasks that run concurrently only touch disjoint
/// sub-ranges, and a parent merges its two halves only after both have
/// finished. The comparator, however, is invoked from several threads at once
/// and must be safe to call concurrently.
///
/// @tparam T element type. It must be copyable (the comparator receives its
///           arguments by value) and move-assignable.
template<typename T>
class MergeSorterMT {
public:
    /// @param cmp       callable with signature `bool(T, T)`; returns true when
    ///                  its first argument has to be placed before the second.
    /// @param max_depth number of recursion levels that may start a worker
    ///                  task; a value <= 0 gives a purely sequential sort.
    template<typename C>
    MergeSorterMT(C cmp, int max_depth) : cmp(std::move(cmp)), max_depth(max_depth) {
        static_assert(std::is_same<std::invoke_result_t<C, T, T>, bool>(), "C must be a function that returns a bool");
    }

    /// Sorts `data` in place according to the comparator.
    ///
    /// An exception thrown by the comparator (on any thread) or by an
    /// allocation is propagated to the caller; `data` then still holds valid
    /// objects, but their order and values are unspecified.
    auto sort(std::vector<T> &data) -> void {
        split(data.begin(), data.end(), 0);
    }

private:
    using Iter = typename std::vector<T>::iterator;

    /// Merges the sorted ranges [first, mid) and [mid, last) into [first, last).
    ///
    /// The left element is taken whenever `cmp(left, right)` is true, otherwise
    /// the right one. With a strict comparator such as `<`, equal elements are
    /// therefore taken from the right half first.
    auto merge(Iter first, Iter mid, Iter last) -> void {
        std::vector<T> merged;
        merged.reserve(static_cast<typename std::vector<T>::size_type>(std::distance(first, last)));

        Iter left = first;
        Iter right = mid;
        while (left != mid && right != last) {
            // push_back instead of insert-at-iterator: inserting invalidates
            // the iterator that marks the insertion point.
            if (cmp(*left, *right)) {
                merged.push_back(std::move(*left));
                ++left;
            } else {
                merged.push_back(std::move(*right));
                ++right;
            }
        }
        std::move(left, mid, std::back_inserter(merged));

        // If the left half ran out first, `merged` holds exactly the elements
        // of [first, right) and the remaining right elements already sit in
        // their final positions, so only `merged` has to be written back.
        std::move(merged.begin(), merged.end(), first);
    }

    /// Recursively sorts [first, last).
    ///
    /// @param depth current recursion level; a worker task is started only
    ///              while `depth < max_depth`.
    auto split(Iter first, Iter last, int depth) -> void {
        const auto count = std::distance(first, last);
        if (count <= 1) {
            return;
        }
        const Iter mid = std::next(first, count / 2);

        std::future<void> left_task;
        if (depth < max_depth) {
            try {
                left_task = std::async(std::launch::async,
                                       [this, first, mid, depth] { split(first, mid, depth + 1); });
            } catch (const std::system_error &) {
                // The system could not start another thread; the left half is
                // sorted on the current thread below instead.
            }
        }
        if (!left_task.valid()) {
            split(first, mid, depth + 1);
        }

        // The calling thread always handles the right half itself instead of
        // idling. Should this throw, the future's destructor waits for the
        // worker, so the task never outlives the data it works on.
        split(mid, last, depth + 1);

        if (left_task.valid()) {
            left_task.get();  // waits for the worker and rethrows its exception
        }
        merge(first, mid, last);
    }

private:
    std::function<bool(T, T)> cmp;
    const int max_depth;
};