// aca-tasks/task1/include/mergesort_mt.h

#pragma once

#include <algorithm>
#include <cstddef>
#include <exception>
#include <functional>
#include <iterator>
#include <mutex>
#include <span>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

// General-purpose merge sorter with multithreading support, by Robin Dietzel <robin.dietzel@iem.thm.de>
/// General-purpose merge sorter with optional multithreading.
///
/// @tparam T Element type of the `std::vector<T>` passed to sort(). Elements are
///           moved (never copied), so T only needs to be movable and swappable.
///
/// Sorting is stable when the comparator is a strict weak ordering (behaves like
/// `operator<`): elements that compare equivalent keep their original relative
/// order.
///
/// Threading: every recursion level with `depth < max_depth` sorts its left half
/// on a newly spawned `std::thread` while the calling thread sorts the right half,
/// so at most 2^max_depth - 1 worker threads are created per sort() call. The
/// comparator is therefore invoked from several threads at once and must be safe
/// to call concurrently.
template<typename T>
class MergeSorterMT {
public:
    /// @brief Creates a sorter.
    /// @tparam C Callable as `bool(const T&, const T&)`; `cmp(a, b)` returns true
    ///           if `a` must be ordered before `b`.
    /// @param comparator  Ordering predicate (strict weak ordering).
    /// @param depth_limit Number of recursion levels that spawn a worker thread;
    ///                    0 or a negative value sorts entirely on the calling thread.
    template<typename C>
    MergeSorterMT(C comparator, int depth_limit)
        : cmp(std::move(comparator)), max_depth(depth_limit) {
        // is_invocable_r_v yields a readable diagnostic (instead of a hard error
        // inside invoke_result_t) and accepts results convertible to bool.
        static_assert(std::is_invocable_r_v<bool, C &, const T &, const T &>,
                      "C must be callable as bool(const T&, const T&)");
    }

    /// @brief Sorts @p data in place.
    ///
    /// If the comparator throws (or a worker thread cannot be started), all
    /// worker threads are joined before the exception is rethrown to the caller;
    /// @p data is then in a valid but unspecified order and may contain
    /// moved-from elements.
    auto sort(std::vector<T> &data) -> void {
        // An iterator pair is a zero-copy view on the vector, just like a span.
        split(data.begin(), data.end(), 0);
    }

private:
    using Iter = typename std::vector<T>::iterator;

    // Recursively sorts [first, last); `depth` is the current recursion level.
    auto split(Iter first, Iter last, int depth) -> void {
        const auto size = std::distance(first, last);
        if (size <= 1) {
            return; // empty or single element: already sorted
        }
        if (size == 2) {
            // Order a pair directly instead of recursing and merging. Swapping only
            // on a strict "less" keeps equivalent elements in place (stability).
            if (cmp(first[1], first[0])) {
                std::iter_swap(first, first + 1);
            }
            return;
        }

        const Iter mid = first + size / 2;

        if (depth < max_depth) {
            // Left half on a worker, right half on this thread: one new thread per
            // node instead of two, and this thread works instead of idling.
            std::exception_ptr worker_error;
            std::thread worker([this, first, mid, depth, &worker_error] {
                try {
                    split(first, mid, depth + 1);
                } catch (...) {
                    // An exception escaping a thread function calls std::terminate;
                    // hand it to the parent thread instead.
                    worker_error = std::current_exception();
                }
            });
            try {
                split(mid, last, depth + 1);
            } catch (...) {
                // Destroying a joinable std::thread calls std::terminate, so the
                // worker must be joined before the exception leaves this frame.
                worker.join();
                throw;
            }
            worker.join();
            if (worker_error) {
                std::rethrow_exception(worker_error);
            }
        } else {
            // Depth limit reached: plain sequential recursion.
            split(first, mid, depth + 1);
            split(mid, last, depth + 1);
        }

        merge(first, mid, last);
    }

    // Merges the sorted ranges [first, mid) and [mid, last) into [first, last).
    // Concurrent calls always operate on disjoint ranges, so no locking is needed.
    auto merge(Iter first, Iter mid, Iter last) -> void {
        // Both halves live inside the destination range, so merge into a scratch
        // buffer first and move the result back afterwards.
        std::vector<T> buf;
        buf.reserve(static_cast<std::size_t>(std::distance(first, last)));

        Iter left = first;
        Iter right = mid;
        while (left != mid && right != last) {
            // Take from the right half only if strictly smaller; ties take the
            // left (earlier) element first, which keeps the sort stable.
            if (cmp(*right, *left)) {
                buf.push_back(std::move(*right));
                ++right;
            } else {
                buf.push_back(std::move(*left));
                ++left;
            }
        }
        // At most one of the halves still has elements; append the remainder.
        std::move(left, mid, std::back_inserter(buf));
        std::move(right, last, std::back_inserter(buf));

        std::move(buf.begin(), buf.end(), first);
    }

    // Ordering predicate; taken by const reference so comparisons never copy T.
    std::function<bool(const T &, const T &)> cmp;
    // Number of recursion levels that spawn a worker thread.
    const int max_depth;
};