diff --git a/task1/main.cpp b/task1/main.cpp index 7b770d7..a2fdad9 100644 --- a/task1/main.cpp +++ b/task1/main.cpp @@ -67,6 +67,13 @@ auto main(int argc, char *argv[]) -> int { //const int max_depth = std::thread::hardware_concurrency(); const int max_depth = 4; t1 = std::chrono::high_resolution_clock::now(); + MergeSorterMT ms([](int a, int b) { + return (a>b); + }); + std::span t = dataset_par; + int mdepth = 4; + std::recursive_mutex mut; + ms.split(t, 0, mdepth, mut); algo::MergeSort_mt::sort(dataset_par, [](int32_t a, int32_t b) { return (a > b); }, max_depth); diff --git a/task1/mergesort_mt.h b/task1/mergesort_mt.h index 51fc7a1..f88ff4b 100644 --- a/task1/mergesort_mt.h +++ b/task1/mergesort_mt.h @@ -6,7 +6,11 @@ template class MergeSorterMT { - C comp; +public: + MergeSorterMT(C cmp) : cmp(cmp){} + + + C cmp; std::recursive_mutex mut; auto merge(std::span &output, std::span left, std::span right) -> void { @@ -44,8 +48,29 @@ class MergeSorterMT { } } - auto split(std::vector &data, C cmp, int depth, int &max_depth, std::mutex &mut) -> void { + auto split(std::span &data, int depth, int &max_depth, std::recursive_mutex &mut) -> void { + if(std::distance(data.begin(), data.end()) <= 1) { + return; + } + auto mid = data.begin(); + std::advance(mid, std::distance(data.begin(), data.end())/2); + + std::span left(data.begin(), mid); + std::span right(mid, data.end()); + + if(depth < max_depth) { + std::thread left_thread(&MergeSorterMT::split, this, left, depth + 1, max_depth, mut); + std::thread right_thread(&MergeSorterMT::split, this, right, depth + 1, max_depth, mut); + + left_thread.join(); + right_thread.join(); + } else { + split(left, depth + 1, max_depth, mut); + split(right, depth + 1, max_depth, mut); + } + + merge(data, left, right); }